diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md new file mode 100644 index 0000000000000000000000000000000000000000..7b391279e479ade4ed5327728f19be8752e11507 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md @@ -0,0 +1,24 @@ +--- +name: TensorFlow Lite Op Request +about: Use this template for reporting ops you are using or missing. + +--- + + +**System information** +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- TensorFlow installed from (source or binary): +- TensorFlow version (or github SHA if from source): + + +**Provide the text output from tflite_convert** + +``` +# Copy and paste here +``` + +Also, please include a link to a GraphDef or the model if possible. + +**Any other info / logs** + +Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/.gitignore b/.gitignore index 57d84228cfd037325716b5faa56c17f7424fe713..90324058600bee46af56e49028977971848a80de 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,7 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata -/tensorflow/lite/downloads/** +/tensorflow/lite/tools/make/downloads/** /tensorflow/lite/gen/** /tensorflow/lite/examples/ios/simple/data/*.txt /tensorflow/lite/examples/ios/simple/data/*.tflite diff --git a/CODEOWNERS b/CODEOWNERS index 54a61a4d72c40d297d90d53e223f64f813d9167d..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,7 +1,7 @@ # Where component owners are known, add them here. 
/tenosrflow/core/debug @caisq -/tensorflow/core/nccl/ @azaks @csigg +/tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar @@ -51,13 +51,13 @@ /tensorflow/contrib/pi_examples/ @maciekcc /tensorflow/contrib/quantization/ @petewarden /tensorflow/contrib/rnn/ @ebrevdo @scottzhu -/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie /tensorflow/contrib/seq2seq/ @ebrevdo @lmthang /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh /tensorflow/contrib/slim/ @sguada @thenbasilmanran /tensorflow/contrib/stateless/ @girving @alextp /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -/tensorflow/contrib/tensorrt/ @aaroey +/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2 # NEED OWNER: /tensorflow/contrib/testing/ /tensorflow/contrib/timeseries/ @allenlavoie /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj diff --git a/README.md b/README.md index 8af5370befbb090966a8b3af54d80c84a969aaa5..044174947a094d43a51f7140dd40ec0f17801d40 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,14 @@ |-----------------| | [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | -**TensorFlow** is an open source software library for numerical computation using -data flow graphs. The graph nodes represent mathematical operations, while +**TensorFlow** is an open source software library for numerical computation +using data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture enables you to deploy computation to one -or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. 
TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit. +between them. This flexible architecture enables you to deploy computation to +one or more CPUs or GPUs in a desktop, server, or mobile device without +rewriting code. TensorFlow also includes +[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization +toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -111,22 +113,24 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA -**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA +**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA **IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) **IBM ppc64le GPU** Stable Release | [![Build 
Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) ## For more information -* [TensorFlow Website](https://www.tensorflow.org) -* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) -* [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow Twitter](https://twitter.com/tensorflow) -* [TensorFlow Blog](https://medium.com/tensorflow) -* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) -* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) -* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) + +* [TensorFlow Website](https://www.tensorflow.org) +* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) +* [TensorFlow Model Zoo](https://github.com/tensorflow/models) +* [TensorFlow Twitter](https://twitter.com/tensorflow) +* [TensorFlow Blog](https://medium.com/tensorflow) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) +* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) +* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. 
diff --git a/WORKSPACE b/WORKSPACE index 0c7bc085b512b084b9470abe17326d7c119aa327..7cc08e0164a202581ad7ebbe107a9e19410e70e4 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,5 +1,7 @@ workspace(name = "org_tensorflow") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + http_archive( name = "io_bazel_rules_closure", sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", @@ -57,9 +59,9 @@ android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() -new_http_archive( +http_archive( name = "inception_v1", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", @@ -67,9 +69,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_ssd", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", @@ -77,9 +79,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_multibox", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", @@ -87,9 +89,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "stylize", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", @@ -97,9 +99,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "speech_commands", - build_file = "models.BUILD", + build_file = 
"//:models.BUILD", sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", diff --git a/configure.py b/configure.py index 234561d94a46f57c4de5ca487360e2d5a3dfdb2f..6c905a0be3d685b5921dfbc5bddfbe6471a82625 100644 --- a/configure.py +++ b/configure.py @@ -238,6 +238,13 @@ def setup_python(environ_cp): write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path + # If choosen python_lib_path is from a path specified in the PYTHONPATH + # variable, need to tell bazel to include PYTHONPATH + if environ_cp.get('PYTHONPATH'): + python_paths = environ_cp.get('PYTHONPATH').split(':') + if python_lib_path in python_paths: + write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH')) + # Write tools/python_bin_path.sh with open( os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), @@ -445,11 +452,12 @@ def convert_version_to_int(version): return int(version_str) -def check_bazel_version(min_version): - """Check installed bazel version is at least min_version. +def check_bazel_version(min_version, max_version): + """Check installed bazel version is between min_version and max_version. Args: min_version: string for minimum bazel version. + max_version: string for maximum bazel version. Returns: The bazel version detected. @@ -467,6 +475,7 @@ def check_bazel_version(min_version): min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) + max_version_int = convert_version_to_int(max_version) # Check if current bazel version can be detected properly. if not curr_version_int: @@ -480,6 +489,10 @@ def check_bazel_version(min_version): print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' 
% min_version) sys.exit(0) + if curr_version_int > max_version_int: + print('Please downgrade your bazel installation to version %s or lower to ' + 'build TensorFlow!' % max_version) + sys.exit(0) return curr_version @@ -859,7 +872,7 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_paths_full = [ os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths ] - if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): + if any(os.path.exists(x) for x in cuda_toolkit_paths_full): break # Reset and retry @@ -1552,7 +1565,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.15.0') + check_bazel_version('0.15.0', '0.20.0') reset_tf_configure_bazelrc() # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later @@ -1694,6 +1707,7 @@ def main(): config_info_line('nohdfs', 'Disable HDFS support.') config_info_line('noignite', 'Disable Apacha Ignite support.') config_info_line('nokafka', 'Disable Apache Kafka support.') + config_info_line('nonccl', 'Disable NVIDIA NCCL support.') if __name__ == '__main__': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 859dc3b8d77be66e0f51e15d86188399273af23f..fd4b94202aad24a82abef8abd16431f61a8326f0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = ( TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) +# @unused +TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( + TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. 
config_setting( @@ -213,31 +218,37 @@ config_setting( # config_setting( name = "no_aws_support", - define_values = {"no_aws_support": "false"}, + define_values = {"no_aws_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_gcp_support", - define_values = {"no_gcp_support": "false"}, + define_values = {"no_gcp_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_hdfs_support", - define_values = {"no_hdfs_support": "false"}, + define_values = {"no_hdfs_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_ignite_support", - define_values = {"no_ignite_support": "false"}, + define_values = {"no_ignite_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_kafka_support", - define_values = {"no_kafka_support": "false"}, + define_values = {"no_kafka_support": "true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_nccl_support", + define_values = {"no_nccl_support": "true"}, visibility = ["//visibility:public"], ) @@ -350,7 +361,7 @@ package_group( "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", - "//tensorflow_estimator/...", + "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", "//tensorflow_text/...", "//third_party/py/tensor2tensor/...", @@ -554,18 +565,24 @@ genrule( }), outs = ["__init__.py"], cmd = select({ - "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", - "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)", }), ) gen_api_init_files( name = "tf_python_api_gen_v1", - srcs = ["api_template_v1.__init__.py"], + srcs = [ + "api_template_v1.__init__.py", + "compat_template_v1.__init__.py", + ], api_version = 1, + compat_api_versions = [1], + compat_init_templates = ["compat_template_v1.__init__.py"], output_dir = 
"_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1, + output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, output_package = "tensorflow._api.v1", + root_file_name = "v1.py", root_init_template = "api_template_v1.__init__.py", ) @@ -581,6 +598,7 @@ gen_api_init_files( output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", + root_file_name = "v2.py", root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 0d49756838505289a960a6cabeb7cab02fad995b..d81cf067eb07e88e2b8a86cf5643674235eb3f3b 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -21,8 +21,6 @@ from __future__ import print_function as _print_function import os as _os # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import - from tensorflow.python.tools import component_api_helper as _component_api_helper _component_api_helper.package_hook( parent_package_str=__name__, @@ -30,16 +28,16 @@ _component_api_helper.package_hook( # API IMPORTS PLACEHOLDER -from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top - # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +# We're using bitwise, but there's nothing special about that. +_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable if _tf_api_dir not in __path__: __path__.append(_tf_api_dir) -# Calls to enable and disable features. 
-enable_eager_execution() # pylint: disable=undefined-variable +# Enable TF2 behaviors +from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +_compat.enable_v2_behavior() # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index b8db1b2144978e97bd32f62e643c2c4a7fcf1654..25df970ecab0757f23465ab19e7f45de0c759458 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -60,6 +60,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:op_gen_lib", + "//tensorflow/core/distributed_runtime:server_lib", ], }), ) @@ -120,7 +121,8 @@ tf_cuda_library( ":c_api", ":c_api_internal", "//tensorflow/c/eager:c_api", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/compiler/jit:flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -173,6 +175,60 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "env", + srcs = [ + "env.cc", + ], + hdrs = [ + "env.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + "//conditions:default": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + }) + [":c_api_internal"], +) + +tf_cuda_library( + name = "kernels", + srcs = [ + "kernels.cc", + ], + hdrs = [ + "kernels.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":c_api_internal", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + ], + 
"//conditions:default": [ + ":c_api", + ":c_api_internal", + ":tf_status_helper", + "//tensorflow/core:framework", + ], + }), +) + # ----------------------------------------------------------------------------- # Tests @@ -208,7 +264,10 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), - tags = ["noasan"], + tags = [ + "no_oss", # http://b/119522529 + "noasan", + ], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -237,7 +296,7 @@ tf_cuda_cc_test( tf_cc_test( name = "c_api_experimental_test", - size = "small", + size = "medium", srcs = ["c_api_experimental_test.cc"], data = ["testdata/tf_record"], linkopts = select({ @@ -248,8 +307,11 @@ tf_cc_test( # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), deps = [ + ":c_api", ":c_api_experimental", ":c_test_util", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -300,6 +362,51 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "env_test", + size = "small", + srcs = ["env_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. 
+ # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cuda_cc_test( + name = "kernels_test", + size = "small", + srcs = ["kernels_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- # Python API target diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index f13e8777dff164bcd8eedf46310ae846abd0c804..94d18eb8b04e3534be547aca5cfbb32da40ffbf6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) { namespace { class TF_ManagedBuffer : public TensorBuffer { public: - void* data_; - size_t len_; - void (*deallocator_)(void* data, size_t len, void* arg); - void* deallocator_arg_; + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg) + : TensorBuffer(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg) {} + + const size_t len_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; ~TF_ManagedBuffer() override { - (*deallocator_)(data_, len_, deallocator_arg_); + (*deallocator_)(data(), len_, deallocator_arg_); } - void* data() const override { return data_; } size_t size() const override { return len_; } TensorBuffer* root_buffer() 
override { return this; } void FillAllocationDescription(AllocationDescription* proto) const override { @@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, dimvec[i] = static_cast(dims[i]); } - TF_ManagedBuffer* buf = new TF_ManagedBuffer; - buf->len_ = len; + TF_ManagedBuffer* buf = nullptr; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != @@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // // Other types have the same representation, so copy only if it is safe to // do so. - buf->data_ = allocate_tensor("TF_NewTensor", len); - std::memcpy(buf->data_, data, len); - buf->deallocator_ = deallocate_buffer; - buf->deallocator_arg_ = nullptr; + buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len, + deallocate_buffer, nullptr); + std::memcpy(buf->data(), data, len); // Free the original buffer. deallocator(data, len, deallocator_arg); } else { - buf->data_ = data; - buf->deallocator_ = deallocator; - buf->deallocator_arg_ = deallocator_arg; + buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); } + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; size_t elem_size = TF_DataTypeSize(dtype); if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { @@ -477,9 +480,9 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { CHECK_EQ(nelems, 0); static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), "64-bit int types should match in size"); - return TF_NewTensor(dtype, reinterpret_cast(dims.data()), - shape.dims(), reinterpret_cast(&empty), 0, - [](void*, size_t, void*) {}, nullptr); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); } // Non-static for testing. 
@@ -1592,18 +1595,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, break; \ } - LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); LIST_CASE(i, TF_ATTR_INT); LIST_CASE(f, TF_ATTR_FLOAT); LIST_CASE(b, TF_ATTR_BOOL); LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); + }); LIST_CASE(tensor, TF_ATTR_TENSOR); LIST_CASE(tensor, TF_ATTR_FUNC); #undef LIST_CASE diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3d56268110edbe96616201d15a69cc8c84d3115a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -91,7 +91,7 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(void); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. @@ -157,7 +157,7 @@ typedef enum TF_Code { typedef struct TF_Status TF_Status; // Return a new status object. 
-TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); // Delete a previously created status object. TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); @@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); // Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); @@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. @@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. @@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); @@ -1611,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. 
-TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); // TF_ApiDefMap encapsulates a collection of API definitions for an operation. // diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index fabe2fa0f60bc8baafa7f83802da74bb7ab93c6d..38e29aa74a90f4e85d1369b6928a5a58c531b2da 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,13 +15,18 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -51,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. 
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -71,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. - tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -6525,7 +6530,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/cycle_length" + name: "ExperimentalParallelInterleaveDataset/cycle_length" op: "Const" attr { key: "dtype" @@ -6546,7 +6551,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/block_length" + name: "ExperimentalParallelInterleaveDataset/block_length" op: "Const" attr { key: "dtype" @@ -6567,7 +6572,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/sloppy" + name: "ExperimentalParallelInterleaveDataset/sloppy" op: "Const" attr { key: "dtype" @@ -6588,7 +6593,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/buffer_output_elements" + name: "ExperimentalParallelInterleaveDataset/buffer_output_elements" op: "Const" attr { key: "dtype" @@ -6609,7 +6614,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/prefetch_input_elements" + name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements" op: "Const" attr { key: "dtype" @@ -6630,14 +6635,14 @@ library { } } node_def { - name: "ParallelInterleaveDataset" - op: "ParallelInterleaveDataset" + 
name: "ExperimentalParallelInterleaveDataset" + op: "ExperimentalParallelInterleaveDataset" input: "RepeatDataset:handle:0" - input: "ParallelInterleaveDataset/cycle_length:output:0" - input: "ParallelInterleaveDataset/block_length:output:0" - input: "ParallelInterleaveDataset/sloppy:output:0" - input: "ParallelInterleaveDataset/buffer_output_elements:output:0" - input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0" + input: "ExperimentalParallelInterleaveDataset/block_length:output:0" + input: "ExperimentalParallelInterleaveDataset/sloppy:output:0" + input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0" attr { key: "Targuments" value { @@ -6737,7 +6742,7 @@ library { node_def { name: "ShuffleDataset_2" op: "ShuffleDataset" - input: "ParallelInterleaveDataset:handle:0" + input: "ExperimentalParallelInterleaveDataset:handle:0" input: "ShuffleDataset_2/buffer_size_1:output:0" input: "ShuffleDataset_2/seed_2:output:0" input: "ShuffleDataset_2/seed2_2:output:0" @@ -8739,14 +8744,65 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } -TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, - const char* errMsg) { +struct TFE_ExecuteOpNotification { + TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} + tensorflow::Notification n; + std::unique_ptr thread; + std::unique_ptr status; +}; + +TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op, + TFE_TensorHandle** retvals, + int* num_retvals, + TF_Status* status) { + TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification; + + n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread( + tensorflow::ThreadOptions(), "ExecuteOpThread", + [op, retvals, num_retvals, n]() { + TFE_Execute(op, retvals, num_retvals, n->status.get()); + 
n->n.Notify(); + })); + + return n; +} + +void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status) { + if (notification == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification is a nullptr."); + + return; + } + if (notification->thread == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification didn't start a thread correctly. Cleaning up " + "this notification. Please re-execute the operation to get a new " + "notification."); + + delete notification; + return; + } + + notification->n.WaitForNotification(); + + status->status = notification->status->status; + + delete notification; +} + +void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } // This builder is used in the eager API to build a NodeDef. struct TF_AttrBuilder : public tensorflow::AttrBuilder { using tensorflow::AttrBuilder::AttrBuilder; + // The string buffers to make sure that any `attr_name` we pass into + // `builder->Set()` will outlive the subsequent + // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`. 
+ std::set attr_names; }; TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) { @@ -8757,13 +8813,15 @@ void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; } void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name, TF_DataType value) { - builder->Set(attr_name, static_cast(value)); + auto iter = builder->attr_names.insert(attr_name).first; + builder->Set((*iter).c_str(), static_cast(value)); } void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name, const TF_DataType* values, int num_values) { + auto iter = builder->attr_names.insert(attr_name).first; builder->Set( - attr_name, + (*iter).c_str(), tensorflow::gtl::ArraySlice( reinterpret_cast(values), num_values)); } @@ -8800,3 +8858,31 @@ const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index, // The returned string is owned by OpRegistry, so liveness is not a concern. return input_arg.number_attr().c_str(); } + +int TF_OpIsStateful(const char* op_type, TF_Status* status) { + const tensorflow::OpRegistrationData* op_reg_data; + status->status = + tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data); + if (!status->status.ok()) { + return 0; + } + return op_reg_data->op_def.is_stateful(); +} + +void TF_InitMain(const char* usage, int* argc, char*** argv) { + tensorflow::port::InitMain(usage, argc, argv); +} + +int TF_PickUnusedPortOrDie() { + return tensorflow::internal::PickUnusedPortOrDie(); +} + +TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg, + void* data, size_t len) { + auto dtype = static_cast(dtype_arg); + DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype)); + + tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({})); + std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); + return new TFE_TensorHandle(tensor, nullptr, nullptr); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 
6639b0be72bdf81d0e3c806770364d7bc5082ad2..3e3a485eb763b871b0551414c4ef04746b2ed9a3 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -180,6 +180,25 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; + +// Allows invoking a kernel asynchronously, and explicitly returns a +// notification that can be waited upon. This always executes the kernel in a +// new thread. +// 1. `retvals` and `num_retvals` can only be consumed after +// `TFE_ExecuteOp` returns successfully. They shouldn't be used +// if the return is unsuccessful +// 2. These new APIs cannot be used together with the TFE context level async +// support. +TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread( + TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, + TF_Status* status); + +// Waits to complete the op execution, and cleans up the notification. +// Errors reported by op execution are set in `status`. +TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status); + TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); @@ -209,6 +228,24 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice( TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput( const char* op_name, int input_index, TF_Status* status); +// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined +// if the status is not ok. +TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type, + TF_Status* status); + +// Platform specific initialization routine. Very few platforms actually require +// this to be called. 
+TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); + +// Platform-specific implementation to return an unused port. (This should used +// in tests only.) +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); + +// Fast path method that makes constructing a single scalar tensor require less +// overhead and copies. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( + TF_DataType dtype, void* scalar, size_t len); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index c6effd39697e0397278770b53e98508074f99862..daa7701b7fe7e8ce757b6504329cf6434ad39778 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -162,5 +164,137 @@ protocol: "grpc" TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, IsStateful) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + int assign = TF_OpIsStateful("AssignAddVariableOp", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(assign, 1); + int id = TF_OpIsStateful("Identity", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(id, 0); +} + +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + TFE_Op* 
matmul_op = MatMulOp(ctx, m, m); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + auto* r = + TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(r, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteOp(matmul_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + +// Perform a send/recv test. Recv blocks, so they need to be executed +// asynchronously. +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4. + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + // Build a send op. 
+ TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(send_op, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + string tensor_name = "Tensor"; + TFE_OpSetAttrType(send_op, "T", TF_FLOAT); + TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + string send_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234); + string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(send_op, "client_terminated", true); + + // Build a recv op. + TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT); + TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234); + TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(recv_op, "client_terminated", true); + + TFE_TensorHandle* send_retvals; + int send_num_retvals = 0; + auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals, + &send_num_retvals, status); + + TFE_TensorHandle* recv_retvals[1] = {nullptr}; + int recv_num_retvals = 1; + auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0], + &recv_num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(send_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << 
TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, product[0]); + EXPECT_EQ(2, product[1]); + EXPECT_EQ(3, product[2]); + EXPECT_EQ(4, product[3]); + + TFE_DeleteOp(send_op); + TFE_DeleteOp(recv_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(recv_retvals[0]); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index f68f8a3e90a971b5e4a024feaf26ba498afc48da..28b9f8df9c873ee394eb6a241dd9ac06ba6c8796 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -392,26 +392,26 @@ Status ProcessInputs( EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { input_tensors->reserve(ninputs); for (int i = 0; i < ninputs; ++i) { - const Node& node = inputs[i].oper->node; + Node* node = &inputs[i].oper->node; int idx = inputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - input_tensors->emplace_back(&node, idx); + input_tensors->emplace_back(node, idx); - const auto& iter = input_nodes->find(&node); + const auto& iter = input_nodes->find(node); if (iter == input_nodes->end()) { - input_nodes->insert({&node, {idx}}); + input_nodes->insert({node, {idx}}); } else { auto& indices = iter->second; if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { - return 
InvalidArgument("TF_Output ", node.name(), ":", idx, + return InvalidArgument("TF_Output ", node->name(), ":", idx, " appears more than once in the input list"); } indices.push_back(idx); @@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { output_tensors->reserve(noutputs); for (int i = 0; i < noutputs; ++i) { - const Node& node = outputs[i].oper->node; + Node* node = &outputs[i].oper->node; int idx = outputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing output ", i, " from function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), "Encountered while creating function '", fn_name, "'"); - output_tensors->emplace_back(&node, idx); + output_tensors->emplace_back(node, idx); } return Status::OK(); } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index ba3d8533db7623b8fa7fdf35093abcd1450776b1..c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -50,6 +50,7 @@ tf_cuda_library( ], "//conditions:default": [], }) + [ + "@com_google_absl//absl/memory", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -143,6 +144,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 408277468d7beb23d1b2ab7f9bbccac16332e55a..027d752f420238da867cb9d8c116640e1730caaa 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,9 +21,11 @@ limitations under 
the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/platform/host_info.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, std::unique_ptr* device_mgr) { - std::vector remote_devices; + std::vector> remote_devices; tensorflow::Status status; // TODO(nareshmodi) do this in parallel instead of serially. for (const string& remote_worker : remote_workers) { @@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices( status = s; if (s.ok()) { for (tensorflow::Device* d : *devices) { - remote_devices.push_back(d); + remote_devices.emplace_back(d); } } n.Notify(); @@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices( n.WaitForNotification(); } std::unique_ptr remote_device_mgr( - new tensorflow::DeviceMgr(remote_devices)); + new tensorflow::DeviceMgr(std::move(remote_devices))); TF_RETURN_IF_ERROR(status); @@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - std::vector devices; + std::vector> devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( - new tensorflow::DeviceMgr(devices)); + new tensorflow::DeviceMgr(std::move(devices))); tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); @@ -409,6 +411,18 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } 
+const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, + TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + tensorflow::Device* d = h->handle->device(); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); +} + TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { @@ -458,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { const char* name = op_or_function_name; // Shorthand const tensorflow::AttrTypeMap* types; - status->status = tensorflow::AttrTypeMapForOp(name, &types); - if (status->status.ok()) return new TFE_Op(ctx, name, types); - if (TF_GetCode(status) == TF_NOT_FOUND) { - if (ctx->context.FindFunctionByName(name)) { - status->status = tensorflow::Status::OK(); - return new TFE_Op(ctx, name, nullptr); + bool is_function = false; + status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); + if (status->status.ok()) { + if (is_function && !ctx->context.FindFunctionByName(name)) { + status->status = tensorflow::errors::NotFound( + "'", name, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + tensorflow::port::Hostname(), + ". 
Make sure the operation or function is " + "registered in the binary running in this process."); + return nullptr; } + return new TFE_Op(ctx, name, is_function, types); } return nullptr; } @@ -497,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret; - if (op->operation.is_function()) { - status->status = tensorflow::errors::Unimplemented( - "TODO(apassos): Support for attributes for TensorFlow functions is not " - "ready yet."); - return TF_ATTR_INT; // The compiler requires that we return something. - } status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), attr_name, &ret, is_list); return ret; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index b2454d872207e26feb3764671474a5d87c01f84d..f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -48,7 +48,7 @@ extern "C" { typedef struct TFE_ContextOptions TFE_ContextOptions; // Return a new options object. -TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); // Set the config in TF_ContextOptions.options. // config should be a serialized tensorflow.ConfigProto proto. @@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); + +// Returns the device of the operation that produced `h`. +// If `h` was produced by a copy, returns the destination device of +// the copy. Note that returned device name is not always the device +// holding the tensor handle's memory. If you want the latter, use +// TFE_TensorHandleBackingDeviceName. +// This function will block till the operation that produces `h` has completed. 
+// +// Device on which the kernel of the operation that produced `h` ran. +// +// If `h` was produced by a copy, returns the destination device of +// the copy. +// +// Note that returned device name is not always the device that owns the memory +// that backs the tensor handle. For the latter see +// TFE_TensorHandleBackingDeviceName. +// // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Returns the name of the device in whose memory `h` resides. +// +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName( + TFE_TensorHandle* h, TF_Status* status); + // Return a pointer to a new TFE_TensorHandle that shares the underlying tensor // with `h`. On success, `status` is set to OK. On failure, `status` reflects // the error and a nullptr is returned. diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index fa1b22e3af487b19b8b7885b7c3740b6249c73eb..67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -93,10 +93,9 @@ struct TFE_TensorDebugInfo { }; struct TFE_Op { - // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a - // primitive operation. 
- TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : operation(&ctx->context, op, t) {} + TFE_Op(TFE_Context* ctx, const char* op, bool is_function, + const tensorflow::AttrTypeMap* t) + : operation(&ctx->context, op, is_function, t) {} tensorflow::EagerOperation operation; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 55331022b9dbd0696928fa44430f340f371432ac..6b39b79ee82f9c7baaf856e573a42b7da65691e5 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "absl/strings/match.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" @@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); const int num_devices = TF_DeviceListCount(devices); + bool has_gpu0 = false; + bool has_gpu1 = false; + for (int i = 0; i < num_devices; ++i) { + const char* dev = TF_DeviceListName(devices, i, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + string device_name(dev); + if (device_name.find("GPU:0") != string::npos) { + has_gpu0 = true; + } + if (device_name.find("GPU:1") != string::npos) { + has_gpu1 = true; + } + } const char* kCPUDevice = "CPU:0"; - if (num_devices < 3) { + if (!has_gpu0 || !has_gpu1) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); @@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) { TF_SetStatus(status.get(), TF_OK, ""); + device_name = TFE_TensorHandleBackingDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle 
is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + int num_dims = TFE_TensorHandleNumDims(h, status.get()); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); ASSERT_EQ(num_dims, -1); @@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) { string(TF_Message(status.get()))); } +TEST(CAPI, TensorHandleDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name; + const char* backing_device_name = + TFE_TensorHandleBackingDeviceName(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0")) + << backing_device_name; + + // Disable the test if no GPU is present. 
+ string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* shape_op = ShapeOp(ctx, hgpu); + TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // .device of shape is GPU since the op is executed on GPU + device_name = TFE_TensorHandleDeviceName(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name; + + // .backing_device of shape is CPU since the tensor is backed by CPU + backing_device_name = + TFE_TensorHandleBackingDeviceName(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0")) + << backing_device_name; + + TFE_DeleteOp(shape_op); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 008f088c2dcdd7d9114103516a4702e47a55c6de..bd38127d50c171af801dd1b937acefdba491b4a6 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -104,6 +104,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, 
TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Shape", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + TFE_TensorHandle* TestAxisTensorHandle() { int64_t dims[] = {1}; int data[] = {1}; diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 474cae67c89249af3a62707f0db00ba458ca8f31..75ef9459e93b4f2ed471c423a34565594efc1714 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(); // Return a matmul op multiplying `a` by `b`. TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); +// Return a shape op fetching the shape of `a`. +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a); + // Return an 1-D INT32 tensor containing a single value 1. TFE_TensorHandle* TestAxisTensorHandle(); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 5ba55a203ff70cc64c07e96b5a869a1f11c9334e..5c11f51e8749de84547ae873f5f55ebd42bc4b3d 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -141,8 +141,9 @@ class GradientTape { // null. The result is populated with one tensor per target element. 
Status ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result); @@ -396,6 +397,7 @@ template Status InitialGradients( const VSpace& vspace, gtl::ArraySlice target_tensor_ids, + gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, gtl::FlatMap>* result) { @@ -425,8 +427,13 @@ Status InitialGradients( "none of operations outputs match expected tensor"); } } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + // This target tensor was not generated by any operation recorded on + // the tape, so no gradient needs to be computed from it unless this + // target is also a source. + auto source_tensor = sources_that_are_targets.find(id); + if (source_tensor != sources_that_are_targets.end()) { + (*result)[id].push_back(vspace.Ones(source_tensor->second)); + } } } else { (*result)[id].push_back(output_gradients[i]); @@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024; template Status GradientTape::ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_ids, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), @@ -478,7 +486,8 @@ Status GradientTape::ComputeGradient( std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; - Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + Status s = InitialGradients(vspace, target_tensor_ids, + sources_that_are_targets, 
output_gradients, tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..07b9e8b940c55caf62ae0b81b884bf313d335459 --- /dev/null +++ b/tensorflow/c/env.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +struct TF_StringStream { + std::vector<::tensorflow::string>* list; + size_t position; +}; + +void TF_CreateDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->CreateDir(dirname)); +} + +void TF_DeleteDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteDir(dirname)); +} + +void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, TF_Status* status) { + ::tensorflow::int64 f, d; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, 
::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d)); + *undeleted_file_count = f; + *undeleted_dir_count = d; +} + +void TF_FileStat(const char* filename, TF_FileStatistics* stats, + TF_Status* status) { + ::tensorflow::FileStatistics cc_stats; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + stats->length = cc_stats.length; + stats->mtime_nsec = cc_stats.mtime_nsec; + stats->is_directory = cc_stats.is_directory; + } +} + +void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, + TF_Status* status) { + std::unique_ptr<::tensorflow::WritableFile> f; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (s.ok()) { + *handle = reinterpret_cast(f.release()); + } +} + +void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close()); + delete cc_file; +} + +void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync()); +} + +void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush()); +} + +void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data, + size_t length, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, 
TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, cc_file->Append(::tensorflow::StringPiece{data, length})); +} + +void TF_DeleteFile(const char* filename, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteFile(filename)); +} + +bool TF_StringStreamNext(TF_StringStream* list, const char** result) { + if (list->position >= list->list->size()) { + *result = nullptr; + return false; + } + + *result = list->list->at(list->position++).c_str(); + return true; +} + +void TF_StringStreamDone(TF_StringStream* list) { + delete list->list; + delete list; +} +TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { + auto* children = new std::vector<::tensorflow::string>; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetChildren(dirname, children)); + + auto* list = new TF_StringStream; + list->list = children; + list->position = 0; + return list; +} + +TF_StringStream* TF_GetLocalTempDirectories() { + auto* tmpdirs = new std::vector<::tensorflow::string>; + + ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); + + auto* list = new TF_StringStream; + list->list = tmpdirs; + list->position = 0; + return list; +} + +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { + return ::tensorflow::Env::Default()->NowNanos(); +} + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { + return ::tensorflow::Env::Default()->NowMicros(); +} + +// Returns the number of seconds since the Unix epoch. 
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { + return ::tensorflow::Env::Default()->NowSeconds(); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h new file mode 100644 index 0000000000000000000000000000000000000000..9d27c5da37735042c7476b591e57486dbde33152 --- /dev/null +++ b/tensorflow/c/env.h @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. + +struct TF_WritableFileHandle; +struct TF_StringStream; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +// Creates the specified directory. Typical status codes are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory.
Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. +// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. +TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). 
+TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item pointers are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list.
All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e2206c6befd2167346c64032940d6e8c631e4a3e --- /dev/null +++ b/tensorflow/c/env_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +TEST(TestEnv, TestDirHandling) { + TF_StringStream* tempdirs = TF_GetLocalTempDirectories(); + const char* tempdir; + bool found = false; + while (TF_StringStreamNext(tempdirs, &tempdir)) { + found = true; + + TF_Status* s = TF_NewStatus(); + + ::tensorflow::string dirpath = + ::tensorflow::io::JoinPath(tempdir, "somedir"); + TF_CreateDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " + << TF_Message(s); + + ::tensorflow::string filepath = + ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + TF_WritableFileHandle* handle; + TF_NewWritableFile(filepath.c_str(), &handle, s); + ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": " + << TF_Message(s); + + const char* data = "Hello, world!\n"; + TF_AppendWritableFile(handle, data, strlen(data), s); + ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at " + << filepath << ": " << TF_Message(s); + + TF_CloseWritableFile(handle, s); + ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to " + << filepath << ": " << TF_Message(s); + + TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; + const char* childpath; + ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + // There should only be one file in this directory. 
+ ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(childpath, nullptr); + TF_StringStreamDone(children); + + TF_FileStatistics stats; + TF_FileStat(filepath.c_str(), &stats, s); + ASSERT_EQ(stats.length, strlen(data)); + ASSERT_FALSE(stats.is_directory); + ASSERT_GT(stats.mtime_nsec, 0); + + // Trying to delete a non-empty directory should fail. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_NE(TF_OK, TF_GetCode(s)) + << "TF_DeleteDir unexpectedly succeeded with a non-empty directory " + << dirpath; + + TF_DeleteFile(filepath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": " + << TF_Message(s); + + // Now deleting the directory should work. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": " + << TF_Message(s); + + TF_DeleteStatus(s); + break; + } + + ASSERT_TRUE(found) << "expected at least one temp dir"; + + TF_StringStreamDone(tempdirs); +} + +TEST(TestEnv, TestTimeFunctions) { + ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000 + ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); + ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a4eaecb6cf2740a522b1e849d1306ebde6c4577 --- /dev/null +++ b/tensorflow/c/kernels.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +// This file forms the basis of a stable ABI for third-party kernel +// implementations. It is crucial that changes to this file are made cautiously +// and with a focus on maintaining both source and binary compatibility. + +struct TF_KernelBuilder { + ::tensorflow::KernelDefBuilder* cc_builder; + + void* (*create_function)(TF_OpKernelConstruction*); + void (*compute_function)(void*, TF_OpKernelContext*); + void (*delete_function)(void*); +}; + +TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) { + TF_KernelBuilder* result = new TF_KernelBuilder; + result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name); + result->cc_builder->Device(device_name); + result->create_function = create_func; + result->compute_function = compute_func; + result->delete_function = delete_func; + return result; +} + +void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { + DCHECK_NE(builder, nullptr); + delete builder->cc_builder; + delete builder; +} + +namespace tensorflow { +namespace { + +// An OpKernel whose methods delegate to C function pointers. 
+class COpKernel : public OpKernel { + public: + explicit COpKernel(OpKernelConstruction* ctx, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) + : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) { + if (create_func != nullptr) { + c_kernel_ = + (*create_func)(reinterpret_cast(ctx)); + } else { + c_kernel_ = nullptr; + } + } + + void Compute(OpKernelContext* ctx) override { + (*compute_func_)(c_kernel_, reinterpret_cast(ctx)); + } + + ~COpKernel() override { + if (delete_func_ != nullptr) { + (*delete_func_)(c_kernel_); + } + } + + private: + void (*compute_func_)(void*, TF_OpKernelContext* context); + void (*delete_func_)(void*); + void* c_kernel_; +}; + +// A KernelFactory that returns COpKernel instances. +class KernelBuilderFactory + : public ::tensorflow::kernel_factory::OpKernelFactory { + public: + explicit KernelBuilderFactory(TF_KernelBuilder* builder) + : builder_(builder) {} + ::tensorflow::OpKernel* Create( + ::tensorflow::OpKernelConstruction* context) override { + return new ::tensorflow::COpKernel(context, builder_->create_function, + builder_->compute_function, + builder_->delete_function); + } + ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); } + + private: + TF_KernelBuilder* builder_; +}; +} // namespace +} // namespace tensorflow + +void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, + TF_Status* status) { + using tensorflow::register_kernel::Name; + + tensorflow::kernel_factory::OpKernelRegistrar( + builder->cc_builder->Build(), name, + absl::make_unique(builder)); + + TF_SetStatus(status, TF_OK, ""); +} + +int TF_NumInputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_inputs(); +} + +int TF_NumOutputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return 
cc_ctx->num_outputs(); +} + +void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); + if (TF_GetCode(status) == TF_OK) { + *tensor = result; + } +} + +void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + ::tensorflow::Tensor cc_tensor; + ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + cc_ctx->set_output(i, cc_tensor); + } +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..1a91aa184f11ac8e45b38a1d106c7b445747a7c1 --- /dev/null +++ b/tensorflow/c/kernels.h @@ -0,0 +1,118 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_KERNELS_H_ +#define TENSORFLOW_C_KERNELS_H_ + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// C API for TensorFlow Kernels. +// +// This API allows developers to register custom kernel implementations for +// TensorFlow. +// +// See c_api.h header comments for a discussion about API conventions. +// +// Users wishing to extend TensorFlow with new kernels will call +// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with +// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided +// kernels when necessary. + +struct TF_KernelBuilder; +struct TF_OpKernelConstruction; +struct TF_OpKernelContext; + +// Allocates a new kernel builder and returns a pointer to it. +// +// If non-null, TensorFlow will call create_func when it needs to instantiate +// the kernel. The pointer returned by create_func will be passed to +// compute_func and delete_func, thereby functioning as a "this" pointer for +// referring to kernel instances. +// +// The TF_OpKernelConstruction pointer passed to create_func is owned by +// TensorFlow and will be deleted once create_func returns. It must not be used +// after this. +// +// When TensorFlow needs to perform a computation with this kernel, it will +// call compute_func. This function will receive the pointer returned by +// create_func (or null if no create_func was provided), along with the inputs +// to the computation. +// +// The TF_OpKernelContext pointer received by compute_func is owned by +// TensorFlow and will be deleted once compute_func returns. It must not be used +// after this. +// +// Finally, when TensorFlow no longer needs the kernel, it will call +// delete_func if one is provided. 
This function will receive the pointer +// returned in `create_func` or nullptr if no `create_func` was provided. +// +// The caller should pass the result of this function to +// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for +// some reason, the kernel builder will not be registered, the caller should +// delete it with TF_DeleteKernelBuilder. +TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)); + +// Register the given kernel builder with the TensorFlow runtime. If +// registration fails, the given status will be populated. +// +// This call takes ownership of the `builder` pointer. +TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, + TF_KernelBuilder* builder, + TF_Status* status); + +// Deletes the given TF_KernelBuilder. This should be called only if the kernel +// builder is not registered with TensorFlow via TF_RegisterKernelBuilder. +TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); + +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. 
+TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but +// TF_OK, ctx is left unmodified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, + const TF_Tensor* tensor, + TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_KERNELS_H_ diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e659ee3c3d258a626ccf03a782ec031b5a703a48 --- /dev/null +++ b/tensorflow/c/kernels_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/kernels.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb_text.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +struct MyCustomKernel { + bool created; + bool compute_called; +}; + +static bool delete_called = false; + +static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + return s; +} + +static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { + struct MyCustomKernel* s = static_cast(kernel); + s->compute_called = true; +} + +static void MyDeleteFunc(void* kernel) { + struct MyCustomKernel* s = static_cast(kernel); + EXPECT_TRUE(s->created); + EXPECT_TRUE(s->compute_called); + delete_called = true; + delete s; +} + +namespace tensorflow { + +static std::unique_ptr GetFakeKernel(const char* device_name, + const char* op_name, + Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_device(device_name); + def.add_input("input1"); + def.add_input("input2"); + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, + status); +} + +// Tests registration of a single C kernel and checks that calls through the +// C/C++ boundary are being made. 
+TEST(TestKernel, TestRegisterKernelBuilder) { + const char* kernel_name = "SomeKernelName"; + const char* op_name = "FooOp"; + const char* device_name = "FakeDeviceName1"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + KernelList list; + list.ParseFromArray(buf->data, buf->length); + ASSERT_EQ(1, list.kernel_size()); + ASSERT_EQ(device_name, list.kernel(0).device_type()); + TF_DeleteBuffer(buf); + TF_DeleteStatus(status); + } + + { + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST(TestKernel, TestInputAndOutputCount) { + const char* kernel_name = "InputOutputCounterKernel"; + const char* op_name = "BarOp"; + const char* device_name = "FakeDeviceName2"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + static int num_inputs = 0; + static int num_outputs = 0; + + // A kernel whose Compute function has a side-effect of updating num_inputs + // and num_outputs. Various functions on TF_OpKernelContext are also + // exercised. 
+ auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + num_inputs = TF_NumInputs(ctx); + num_outputs = TF_NumOutputs(ctx); + + TF_Tensor* input = nullptr; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s); + EXPECT_EQ(123, *static_cast(TF_TensorData(input))); + TF_GetInput(ctx, -1, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_GetInput(ctx, 3, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + // Copy the input tensor to output. + TF_SetOutput(ctx, 0, input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_SetOutput(ctx, 24, input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + TF_DeleteStatus(s); + if (input != nullptr) { + TF_DeleteTensor(input); + } + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + + { + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr, false); + p.device = &dummy_device; + + Tensor t(tensorflow::uint8(123)); + + gtl::InlinedVector inputs; + // Simulate 2 inputs + inputs.emplace_back(&t); + inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + + ASSERT_EQ(2, num_inputs); + ASSERT_EQ(1, num_outputs); + ASSERT_EQ(123, ctx.mutable_output(0)->scalar()()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 247236b760dd8c07bbb08426100b6a4d34296d2e..98d8393332269ae349cf8aa5c0b612c6f17172e6 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -160,4 +160,17 
@@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst, "adding input tensor"); + } +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 5cce84020bc68d912d259f51512341eb5f464a2c..44779ca656165dd65590cb5e9ea3ccf71165ed63 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +// Updates 'dst' to consume 'new_src'. void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); @@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // because I couldn't get SWIG to work otherwise. void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. 
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 83353b79f722f0a95f508b32d4a49b14b35624fb..a09becc49b10d2c58f98fbcc11df5190f794c1d4 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -489,6 +489,7 @@ tf_gen_op_wrappers_cc( "image_ops", "io_ops", "linalg_ops", + "list_ops", "logging_ops", "lookup_ops", "manip_ops", diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 3d3895c8fa82c3c0e2974228e9cad767d0e00df4..52345a376cc29ee47ccb9888c9bb26292468b5a9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -133,5 +133,6 @@ filegroup( "testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", + "testdata/half_plus_two_v2/**", ]), ) diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 645a3f101d1ae7dda88ec4ca622c694dc5a7a919..6f00dc324bd7054b28de2c35023581e1666bfa01 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb"; /// SavedModel text format proto filename. constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; -/// SavedModel legacy init op key. +/// SavedModel legacy init op collection key. Used in v1 SavedModels. constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; -/// SavedModel main op key. +/// SavedModel main op collection key. Used in v1 SavedModels. constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; /// Directory in which to save the SavedModel variables. @@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables"; /// SavedModel variables filename. 
constexpr char kSavedModelVariablesFilename[] = "variables"; +/// SavedModel SignatureDef keys for the initialization and train ops. Used in +/// V2 SavedModels. +constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; +constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op"; + } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e..85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -122,34 +122,54 @@ Status RunOnce(const RunOptions& run_options, return run_status; } -bool HasMainOp(const MetaGraphDef& meta_graph_def) { +// RunInitOp will return OK if the initialization op was run successfully. +// An empty init_op_name indicates that there are no init ops to run. +Status RunInitOp(const RunOptions& run_options, const string& export_dir, + const MetaGraphDef& meta_graph_def, + const std::vector& asset_file_defs, + Session* session, const string& init_op_name) { + if (!init_op_name.empty()) { + LOG(INFO) << "Running initialization op on SavedModel bundle."; + std::vector> inputs; + AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); + RunMetadata run_metadata; + return RunOnce(run_options, inputs, {}, {init_op_name}, + nullptr /* outputs */, &run_metadata, session); + } + return Status::OK(); +} + +// A SavedModel may store the name of the initialization op to run in the +// SignatureDef (v2) or a collection (v1). If an init_op collection +// exists, then the collection must contain exactly one op.
+Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, + string* init_op_name) { + const auto& sig_def_map = meta_graph_def.signature_def(); + const auto& init_op_sig_it = + meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey); + if (init_op_sig_it != sig_def_map.end()) { + *init_op_name = init_op_sig_it->second.outputs() + .find(kSavedModelInitOpSignatureKey) + ->second.name(); + return Status::OK(); + } + const auto& collection_def_map = meta_graph_def.collection_def(); + string init_op_collection_key; if (collection_def_map.find(kSavedModelMainOpKey) != collection_def_map.end()) { - return true; + init_op_collection_key = kSavedModelMainOpKey; + } else { + init_op_collection_key = kSavedModelLegacyInitOpKey; } - return false; -} -Status RunMainOp(const RunOptions& run_options, const string& export_dir, - const MetaGraphDef& meta_graph_def, - const std::vector& asset_file_defs, - Session* session, const string& main_op_key) { - LOG(INFO) << "Running MainOp with key " << main_op_key - << " on SavedModel bundle."; - const auto& collection_def_map = meta_graph_def.collection_def(); - const auto main_op_it = collection_def_map.find(main_op_key); - if (main_op_it != collection_def_map.end()) { - if (main_op_it->second.node_list().value_size() != 1) { + const auto init_op_it = collection_def_map.find(init_op_collection_key); + if (init_op_it != collection_def_map.end()) { + if (init_op_it->second.node_list().value_size() != 1) { return errors::FailedPrecondition( strings::StrCat("Expected exactly one main op in : ", export_dir)); } - std::vector> inputs; - AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); - RunMetadata run_metadata; - const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {string(main_op_name)}, - nullptr /* outputs */, &run_metadata, session); + *init_op_name = init_op_it->second.node_list().value(0); } return Status::OK(); } @@ -193,6 
+213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, std::vector* asset_file_defs) { + // With SavedModel v2, we write asset file def into metagraph instead of + // collection, so read from metagraph first. + if (meta_graph_def.asset_file_def_size() > 0) { + for (const auto& asset : meta_graph_def.asset_file_def()) { + asset_file_defs->push_back(asset); + } + return Status::OK(); + } + // Fall back to read from collection to be backward compatible with v1. const auto& collection_def_map = meta_graph_def.collection_def(); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); if (assets_it == collection_def_map.end()) { @@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); - if (HasMainOp(bundle->meta_graph_def)) { - TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get(), kSavedModelMainOpKey)); - } else { - TF_RETURN_IF_ERROR(RunMainOp( - run_options, export_dir, bundle->meta_graph_def, asset_file_defs, - bundle->session.get(), kSavedModelLegacyInitOpKey)); - } + string init_op_name; + TF_RETURN_IF_ERROR( + GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); + TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, + asset_file_defs, bundle->session.get(), + init_op_name)); return Status::OK(); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2..597e42bb65ab5536664089f7e65ec52d77fc8f23 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] = 
"cc/saved_model/testdata/half_plus_two_main_op/00000123"; constexpr char kTestDataSharded[] = "cc/saved_model/testdata/half_plus_two/00000123"; +constexpr char kTestDataInitOpV2[] = + "cc/saved_model/testdata/half_plus_two_v2/00000123"; class LoaderTest : public ::testing::Test { protected: @@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) { EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir)); } +TEST_F(LoaderTest, SavedModelInitOpV2Format) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9ff036688007836524129e23f5cf82edd1e8910 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..a10bbf8fb6bca0fcee6414b2927d2f706de85ebc Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 new file mode 100644 index 
0000000000000000000000000000000000000000..15b75d6ef6bffc336d138d923badb3928b8c4c13 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..7ec9fb4fe2dd21d0a6c324aecd7658fc37cf2326 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index differ diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index b17bc658fa06b9feb7edb292bd89ef31e6309169..ab1c1be344e2257721507543bc7647d4ff4becb2 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code, } // Generate methods for args (inputs). -Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, +Status GenArgMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); if (config.feed_size() != num_args) { @@ -174,9 +175,10 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, } for (int i = 0; i < num_args; ++i) { std::vector> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_arg{{NAME}}_data(void* data) { + void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { @@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, // Generate methods for results (outputs). 
Status GenResultMethods(const tf2xla::Config& config, - const xla::ProgramShape& ps, string* methods) { + const xla::ProgramShapeProto& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. @@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config, } for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector> rewrites; - TF_RETURN_IF_ERROR( - AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); + TF_RETURN_IF_ERROR(AddRewritesForShape( + i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { return static_cast<{{TYPE}}*>(result_data({{I}})); @@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, ExtractEntryParamBufferInfos(buffer_infos); std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); - const xla::ProgramShape& ps = compile_result.program_shape; + const xla::ProgramShapeProto& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); @@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const char** StaticResultNames() {{RESULT_NAMES_CODE}} // Shape of the args and results. 
- static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } @@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, - {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, @@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts, Status GenerateMetadata(const CodegenOpts& opts, const CompileResult& compile_result, MetadataResult* metadata_result) { - std::unique_ptr program_shape; + std::unique_ptr program_shape; if (opts.gen_program_shape) { program_shape = - absl::make_unique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save @@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts, // a shim that evaluates to nullptr, which is what we want. 
ProtobufToEmbed program_shape_protobuf{ - CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", - program_shape.get()}; + CreateUniqueIdentifier(opts, "ProgramShapeProto"), + "xla::ProgramShapeProto", program_shape.get()}; ProtobufToEmbed hlo_profile_printer_data_protobuf{ CreateUniqueIdentifier(opts, "HloProfilePrinterData"), diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 90410c46a8e36e44454f1219ad76d0fb0937070d..9485e86b10e225a3c9c12eafd9905bdf7c15c9fa 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -57,7 +57,7 @@ struct MetadataResult { std::vector header_variable_decls; // program_shape_access_shim is a C++ expression that constructs the - // xla::ProgramShape instance for the CompileResult passed to + // xla::ProgramShapeProto instance for the CompileResult passed to // GenerateMetadata. string program_shape_access_shim; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index bb288d23000527be74f01630d20bbf82e50007ce..c1788ca32a1d099284eeb870f9513891051fd29e 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) { BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, 5, {})); - compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( - { - xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), - xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); + compile_result.program_shape = + xla::ShapeUtil::MakeProgramShape( + { + xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), + xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + }, + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + .ToProto(); compile_result.entry_point = "entry_point"; 
compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index e4d8a02877c75fa72c5747650ab9c7ac229955b3..968afad65ed6d4b5510687df484b7ce6743f6a85 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -22,7 +22,7 @@ extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, const void** args, void** temps, tensorflow::int64* profile_counters); -extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; +extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[]; namespace foo { @@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. - void set_arg0_data(void* data) { + void set_arg0_data(const void* data) { set_arg_data(0, data); } float* arg0_data() { @@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg_myfeed_data(void* data) { + void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); } float* arg_myfeed_data() { @@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg1_data(void* data) { + void set_arg1_data(const void* data) { set_arg_data(1, data); } tensorflow::int64* arg1_data() { @@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } // Shape of the args and results. 
- static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = []() { - xla::ProgramShape* proto = new xla::ProgramShape; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = []() { + xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index eb001c5d45bdfefc76629d7303d89f5480432235..ce8e5ec8c96a2c3696f14b8eea206d648182ecb5 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b5f97b34cd928d32eb220536342c715d91d45bb..9fc223bdc7c0e207ce2005cb86250aa77e709df8 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -56,17 +56,23 @@ Status CompileXla(xla::CompileOnlyClient* client, return errors::Unknown("Couldn't get XLA program shape: ", pshape_or.status().error_message()); } - compile_result->program_shape = *pshape_or.ValueOrDie(); - xla::ProgramShape* pshape = &compile_result->program_shape; - std::vector arg_layouts; - arg_layouts.reserve(pshape->parameters_size()); + compile_result->program_shape = pshape_or.ValueOrDie()->ToProto(); + xla::ProgramShapeProto* pshape = &compile_result->program_shape; + + // AotXlaComputationInstance::argument_layouts is a vector of Shape + // pointers. Accumulate the Shape objects themselves in a separate vector + // while building the vector of pointers. 
+ std::vector arg_layout_ptrs(pshape->parameters_size()); + std::vector arg_layouts(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { - arg_layouts.push_back(pshape->mutable_parameters(i)); + arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i)); + arg_layout_ptrs[i] = &arg_layouts[i]; } xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; - instance.argument_layouts = std::move(arg_layouts); - instance.result_layout = &pshape->result(); + instance.argument_layouts = std::move(arg_layout_ptrs); + xla::Shape result_shape(pshape->result()); + instance.result_layout = &result_shape; xla::StatusOr>> aot_or = client->CompileAheadOfTime({instance}, aot_opts); if (!aot_or.ok()) { diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index e03c5b1aa77c1262ed903aae3072ef65f34d80a2..ee7bb26fabd2d897b85b62f38778ecbfe2238eb6 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -33,9 +33,9 @@ namespace tfcompile { struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; - xla::ProgramShape program_shape; // Static shape of args and results. - string entry_point; // Name of generated function. - int pointer_size = 0; // Size of a pointer in bytes. + xla::ProgramShapeProto program_shape; // Static shape of args and results. + string entry_point; // Name of generated function. + int pointer_size = 0; // Size of a pointer in bytes. 
}; // CompileGraph compiles the graph_def into an object file containing a function diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index f10852c7850f61bfd8b99fa9f1648202d182085e..4dd79e5882d7da61be029735ef2b165908c599f9 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) { // muladd has the program shape defined. MatMulAndAddComp muladd; - const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape(); ASSERT_TRUE(muladd_shape != nullptr); ASSERT_EQ(muladd_shape->parameters_size(), 2); - EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); - EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); + EXPECT_TRUE( + ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2)); + EXPECT_TRUE( + ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2)); - const xla::Shape& muladd_result = muladd_shape->result(); + const xla::Shape muladd_result(muladd_shape->result()); ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); const xla::Shape& muladd_result0 = diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 162a137fa7a5573056911d19472de4261574137a..15dcbb2641eca031e82db9aa58dee6a14ab0a2cc 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -23,7 +23,6 @@ package( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -38,7 +37,7 @@ 
cc_library( ":xla_cpu_device", ":xla_cpu_jit", "//tensorflow/compiler/plugin", - ] + if_cuda_is_configured([ + ] + if_cuda([ ":xla_gpu_device", ":xla_gpu_jit", ]), @@ -51,6 +50,7 @@ cc_library( deps = [ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -76,10 +76,11 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep + ":flags", ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -95,6 +96,7 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", @@ -104,6 +106,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) @@ -210,6 +213,18 @@ cc_library( # Internal targets below this point. 
+cc_library( + name = "flags", + srcs = ["flags.cc"], + hdrs = ["flags.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/xla:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "common", srcs = [ @@ -256,6 +271,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -268,6 +284,7 @@ cc_library( "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -487,6 +504,7 @@ cc_library( deps = [ ":common", ":encapsulate_util", + ":flags", ":shape_inference_helpers", ":union_find", ":xla_cluster_util", @@ -494,8 +512,6 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", @@ -544,25 +560,6 @@ cc_library( hdrs = ["union_find.h"], ) -cc_library( - name = "producer_consumer_queue", - hdrs = ["producer_consumer_queue.h"], - deps = ["//tensorflow/core:lib"], -) - -tf_cc_test( - name = "producer_consumer_queue_test", - size = "small", - srcs = ["producer_consumer_queue_test.cc"], - deps = [ - ":producer_consumer_queue", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -743,7 +740,10 
@@ tf_custom_op_py_library( visibility = [ ":friends", ], - deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], + deps = [ + "//tensorflow/compiler/jit/ops:xla_ops_grad", + "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py", + ], ) # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 93637a69d5d7b6bf9e9ce784ae521ef0e9b121b9..9f4042630edaec1b9519b6434d859a48372e8b15 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -320,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { return IsXlaCompiledKernel(*n); }); - bool lazy_compilation_enabled = enable_lazy_compilation_ - ? *enable_lazy_compilation_ - : legacy_flags::GetBuildXlaOpsPassFlags() - .tf_xla_enable_lazy_compilation; + bool lazy_compilation_enabled = + enable_lazy_compilation_ + ? 
*enable_lazy_compilation_ + : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation; for (Node* n : xla_compiled_kernels) { TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 11df946cc186660242574c2644463a26ead44f1f..48a23a4c1711ac88a329723c46559112d5a39dbd 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test { .ok()); } - void TearDown() override { - for (Device* device : devices_) { - delete device; - } - } - private: - std::vector devices_; + std::vector> devices_; }; using ::tensorflow::testing::FindNodeByName; diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index 73866607621cd745f6e640a14405daebf0dd9985..0f872a480f4d4843217f1df3452c4dc62531264e 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test { SessionOptions options; auto* device_count = options.config.mutable_device_count(); device_count->insert({"CPU", 1}); + std::vector> devices; TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices_)); + options, "/job:localhost/replica:0/task:0", &devices)); FunctionDefLibrary proto; for (const auto& fdef : flib) { @@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test { lib_def_ = absl::make_unique( OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = absl::make_unique(devices_); + device_mgr_ = absl::make_unique(std::move(devices)); pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); @@ -77,7 +78,6 @@ class 
CreateXlaLaunchOpTest : public ::testing::Test { } FunctionLibraryRuntime* flr_; - std::vector devices_; std::unique_ptr device_mgr_; std::unique_ptr lib_def_; std::unique_ptr pflr_; diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 28ec37b1b9c8a1a306b5e778bac5b6ba01c2c997..1f4b9c90a4ff0b1166cdb7b5942771b350740ef3 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -86,7 +86,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, continue; } else if (src_xla_computation && !dst_xla_computation) { if (src_outside_compilation) { - // Case 1d: outside compilation to host computation control edge. + // Case 1c: outside compilation to host computation control edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( @@ -94,7 +94,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, } } else if (!src_xla_computation && dst_xla_computation) { if (dst_outside_compilation) { - // Case 1d: host computation control to outside compilation edge. + // Case 1c: host computation control to outside compilation edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( @@ -103,40 +103,24 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, } else { // src_xla_computation && dst_xla_computation if (*src_xla_computation != *dst_xla_computation) { if (src_outside_compilation && dst_outside_compilation) { - // Case 1c: outside compilation to outside compilation control edge. + // Case 1b: outside compilation to outside compilation control edge. 
edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1b: outside compilation to another XLA computaition control + // Case 1a: outside compilation to another XLA computation control // edge. TF_RETURN_IF_ERROR(AppendToListAttr( e->src(), kXlaConnectedToOtherXlaComputationAttrName, *dst_xla_computation)); } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1b: another XLA computaition to outside compilation control + // Case 1a: another XLA computation to outside compilation control // edge. TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, *src_xla_computation)); } - } else { // *src_xla_computation == *dst_xla_computation - if (src_outside_compilation && dst_outside_compilation) { - if (*src_outside_compilation != *dst_outside_compilation) { - // Case 1c: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to its XLA computation control edge. - ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: XLA computation to outside compilation in it control edge.
- ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); - } } } } @@ -181,12 +165,6 @@ Status ProcessXlaToXlaDataEdges(Graph* g, edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); } - } else { // *src_xla_computation == *dst_xla_computation - if (src_outside_compilation && dst_outside_compilation && - *src_outside_compilation != *dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } } } @@ -263,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( // Remove the edge from host to outside compilation. Add a placeholder as // outside compilation node input. - std::map placeholders; + std::map, Node*> placeholders; for (int i = 0; i < edges.size(); i++) { Node* dst = g->FindNodeId(edges[i].dst_node_id); const Edge* e; @@ -275,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( // Find or create placeholder node. string new_name = edges[i].is_host_to_outside_compilation - ? absl::StrCat(src->name(), "_host_to_oc_placeholder") - : absl::StrCat(src->name(), "_oc_to_host_placeholder"); - auto iter = placeholders.find(new_name); + ? 
absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) + : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); Node* placeholder_node; if (iter == placeholders.end()) { NodeDefBuilder placeholder_builder(new_name, "Placeholder"); @@ -310,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( Status s; placeholder_node = g->AddNode(placeholder_def, &s); TF_RETURN_IF_ERROR(s); - placeholders[new_name] = placeholder_node; + placeholders[placeholder_index] = placeholder_node; } else { placeholder_node = iter->second; } @@ -594,14 +573,244 @@ Status AddControlDependencies( return Status::OK(); } +// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges to remove. We should not remove the edge while iterating. + std::vector edges_to_remove; + for (const Edge* e : g->edges()) { + if (!e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation) { + if (*src_outside_compilation != *dst_outside_compilation) { + // Case 1a: outside compilation to outside compilation control edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName, + e->src()->name())); + } + } else if (src_outside_compilation && !dst_outside_compilation) { + // Case 1b: outside compilation to its XLA computation control edge. 
+ ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); + } else if (!src_outside_compilation && dst_outside_compilation) { + // Case 1b: XLA computation to outside compilation in it control edge. + ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); + } + } + + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + return Status::OK(); +} + +// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges between outside compilation and host computation. Notice that + // we do not store `Edge*` directly because we remove some nodes while adding + // Identity nodes, and those Edge pointers might be invalidated. + struct EdgeInfo { + int dst_input, dst_node_id; + }; + std::vector edges; + for (const Edge* e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation && + *src_outside_compilation != *dst_outside_compilation) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); + VLOG(4) << "Oc -> oc edge: " << e->DebugString(); + } + } + + // Remove the edge from host to outside compilation. Add a placeholder as + // outside compilation node input. + std::map, Node*> placeholders; + for (int i = 0; i < edges.size(); i++) { + Node* dst = g->FindNodeId(edges[i].dst_node_id); + const Edge* e; + TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); + Node* src = e->src(); + int src_output = e->src_output(), dst_input = e->dst_input(); + g->RemoveEdge(e); + + // Find or create placeholder node. 
+ string new_name = + absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); + Node* placeholder_node; + if (iter == placeholders.end()) { + NodeDefBuilder placeholder_builder(new_name, "Placeholder"); + placeholder_builder.Attr("dtype", src->output_type(src_output)); + string outside_compilation_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), + outside_compilation_attr_name, + &outside_compilation_attr)); + placeholder_builder.Attr(outside_compilation_attr_name, + outside_compilation_attr); + placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName, + src->name()); + placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName, + src_output); + NodeDef placeholder_def; + TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); + Status s; + placeholder_node = g->AddNode(placeholder_def, &s); + TF_RETURN_IF_ERROR(s); + placeholders[placeholder_index] = placeholder_node; + } else { + placeholder_node = iter->second; + } + g->AddEdge(placeholder_node, 0, dst, dst_input); + + // Replace `e->dst()` because its input node changed. + NodeDef new_def = dst->def(); + *new_def.mutable_input(dst_input) = placeholder_node->name(); + TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); + + // Other edge in `edges` might have `e->dst()` as src or dst + // node. Before removing `e->dst()`, replace those edges with + // corresponding edges for `dst_replace_node`. + for (int j = i + 1; j < edges.size(); j++) { + if (edges[j].dst_node_id == edges[i].dst_node_id) { + edges[j].dst_node_id = dst_replace_node->id(); + } + } + } + return Status::OK(); +} + +// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. 
+Status PostprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather all outside compilation to outside compilation nodes. + std::vector placeholder_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "Placeholder" && + HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) { + placeholder_nodes.push_back(n); + } + } + + // Remove the placeholder nodes, and reconnect original edge. + auto node_name_index = g->BuildNodeNameIndex(); + for (auto n : placeholder_nodes) { + string node_name; + int node_src_output; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name)); + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output)); + auto iter = node_name_index.find(node_name); + if (iter == node_name_index.end()) { + return errors::Internal( + "Cannot find original node for oc -> host placeholder node ", + node_name); + } + + // Change all usage node to use the original node instead. 
+ Node* original_node = iter->second; + std::vector control_edges; + std::vector data_edges; + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) { + control_edges.push_back(e); + } else { + data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); + } + } + for (const Edge* e : control_edges) { + g->AddControlEdge(original_node, e->dst()); + g->RemoveEdge(e); + } + for (int i = 0; i < data_edges.size(); i++) { + Node* dst = data_edges[i].dst; + NodeDef new_def = dst->def(); + int dst_input = data_edges[i].dst_input; + *new_def.mutable_input(dst_input) = + absl::StrCat(original_node->name(), ":", node_src_output); + TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); + + const Edge* edge_to_replace = nullptr; + TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); + g->RemoveEdge(edge_to_replace); + g->AddEdge(original_node, node_src_output, replace_node, dst_input); + + // Other edges might have `dst` as dst node. Update those edges with + // `replace_node`. + for (int j = i + 1; j < data_edges.size(); j++) { + if (data_edges[j].dst == dst) { + data_edges[j].dst = replace_node; + } + } + + // Other placeholder node might have `dst` as original node. Update + // `node_name_index` with `replace_node`. + node_name_index[replace_node->name()] = replace_node; + } + + // Remove placeholder node. + g->RemoveNode(n); + } + return Status::OK(); +} + +// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. +Status PostprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + auto node_name_index = g->BuildNodeNameIndex(); + + // Reconnect outside compilation to outside compilation control edge. 
+ for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = + GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, + &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName); + for (const string& control_input : control_deps) { + auto iter = node_name_index.find(control_input); + if (iter == node_name_index.end()) { + return errors::Internal("Cannot find original node for ", + control_input); + } + g->AddControlEdge(iter->second, n); + } + } + } + return Status::OK(); +} } // namespace const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToXlaComputationAttrName[] = - "_xla_connected_to_xla_computation"; -const char kXlaConnectedFromXlaComputationAttrName[] = - "_xla_connected_from_xla_computation"; const char kXlaConnectedToOtherXlaComputationAttrName[] = "_xla_connected_to_other_xla_computation"; const char kXlaConnectedFromOtherXlaComputationAttrName[] = @@ -616,6 +825,15 @@ const char kHostToOutsideCompilationOriginalNodeAttrName[] = "_xla_host_to_oc_node_name"; const char kHostToOutsideCompilationSrcOutputAttrName[] = "_xla_host_to_oc_src_output"; +const char kXlaConnectedToXlaComputationAttrName[] = + "_xla_connected_to_xla_computation"; +const char kXlaConnectedFromXlaComputationAttrName[] = + "_xla_connected_from_xla_computation"; +const char kOutsideCompilationOriginalNodeAttrName[] = + "_xla_oc_to_oc_node_name"; +const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; +const char kXlaControlDependenciesWithinXlaClusterAttrName[] = + "_xla_control_dependencies_within_xla_cluster"; Status PerformStaticShapeInferenceBeforeEncapsulation( Graph* g, const string& xla_computation_attr_name, @@ -699,4 +917,39 @@ Status PostprocessForEncapsulation( return Status::OK(); } +Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& 
outside_compilation_attr_name) { + // Remove edges from source node to outside compilation nodes, and edges + // from outside compilation nodes to sink node. + std::vector edges_to_remove; + for (const Edge* e : g->source_node()->out_edges()) { + if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (const Edge* e : g->sink_node()->in_edges()) { + if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + + TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 5e0c4bf6a0cc92d69209595e257989665404db6b..e363bc5754ac395bae262dc67a780a0173efaf5e 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -44,14 +44,6 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( Graph* g, const string& xla_computation_attr_name, const string& outside_compilation_attr_name); -// Attribute indicating that some ops in this node's XLA computation has control -// dependency on this node. Attribute value will always be "true". -extern const char kXlaConnectedToXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// this node's XLA computation. 
Attribute value will always be "true". -extern const char kXlaConnectedFromXlaComputationAttrName[]; - // Attribute indicating that some ops in other XLA computation has control // dependency on this node. Attribute value will be a list of string (XLA // computation names). @@ -81,6 +73,14 @@ extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; // int (src_output for original edge). extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +// Attribute indicating that some ops in this node's XLA computation has control +// dependency on this node. Attribute value will always be "true". +extern const char kXlaConnectedToXlaComputationAttrName[]; + +// Attribute indicating that this node has control dependency on some ops in +// this node's XLA computation. Attribute value will always be "true". +extern const char kXlaConnectedFromXlaComputationAttrName[]; + // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an host node. Attribute value will be string // (original input node name). @@ -91,19 +91,31 @@ extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; // for original edge). extern const char kHostToOutsideCompilationSrcOutputAttrName[]; -// Preprocesses the graph for encapsulation. It will perform the following -// operations in order: +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// string (original input node name). +extern const char kOutsideCompilationOriginalNodeAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// int (src_output for original edge). +extern const char kOutsideCompilationSrcOutputAttrName[]; + +// Attribute indicating that this node has control dependencies on some other +// nodes within the same XLA cluster.
Attribute value will be a list of string +// (node names). +extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; + +// Preprocesses edges between different XLA clusters for encapsulation. It will +// perform the following operations in order: // -// 1a. For control edges between outside compilation and its XLA computation, -// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the -// outside compilation node. -// 1b. For control edges between outside compilation and another XLA +// 1a. For control edges between outside compilation and another XLA // computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName // = XLA computation node name" to the outside compilation node. -// 1c. For control edges between different outside compilations, remove the edge -// and add attr "kXlaControlDependenciesAttrName = src node name" to dst -// node. -// 1d. For control edges between outside compilation and host computation, +// 1b. For control edges between different outside compilations (in different +// XLA computations), remove the edge and add attr +// "kXlaControlDependenciesAttrName = src node name" to dst node. +// 1c. For control edges between outside compilation and host computation, // remove the edge and add attr "kXlaControlDependenciesAttrName = src node // name" to dst node. // 2. For data edges between different XLA computations, if either src or dst @@ -146,26 +158,53 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses the graph for encapsulation. This function reverts what -// `PreprocessForEncapsulation` did. It will perform the following operations in -// order: +// Postprocesses edges between different XLA clusters for encapsulation. This +// function reverts what `PreprocessForEncapsulation` did. It will perform the +// following operations in order: // // 1. Remove Placeholder nodes between outside compilation and host computation // (created in `PreprocessForEncapsulation` step 3). 
// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1c) and control edges between outside -// compilation and host computation (marked by `PreprocessForEncapsulation` -// step 1d). -// 3b. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1b). -// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are -// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`. +// 3a. Reconnect control edges between outside compilation and another XLA +// computation (marked by `PreprocessForEncapsulation` step 1a). +// 3b. Reconnect control edges between different outside compilations (marked by +// `PreprocessForEncapsulation` step 1b). +// 3c. Reconnect control edges between outside compilation and host computation +// (marked by `PreprocessForEncapsulation` step 1c). Status PostprocessForEncapsulation( Graph* g, const string& xla_computation_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters); +// Preprocesses edges within the same XLA cluster. It will perform the following +// operations in order: +// +// 0. Remove edges from source node to outside compilation nodes, and edges +// from outside compilation nodes to sink node. +// 1a. For control edges between different outside compilation clusters, remove +// the edge and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = +// src node name" to dst node. +// 1b. For control edges between outside compilation and its XLA computation, +// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the +// outside compilation node. +// 2. For data edges between different outside compilations, remove the edge +// and create a Placeholder node as dst node's input.
+Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); + +// Postprocesses edges within the same XLA cluster. This function reverts what +// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the +// following operations in order: +// +// 1. Remove Placeholder nodes between different outside compilations (created +// in `PreprocessEdgesBetweenOutsideCompilations` step 2). +// 2a. Reconnect control edges between different outside compilations (marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1a). +// Notice that control edges marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. +// They are handled in `RewriteOutsideCompilationSubgraphFn`. +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 7255df3112916b7abcc98ff8204efc8c02209b13..3b8b49cb92f3e453883a8e64e12ce3748a5173f6 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -107,28 +107,19 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { identity4_node->AddAttr("_xla", "1"); identity4_node->AddAttr("_oc", "0"); identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and its XLA computation. - g.AddControlEdge(add_node, identity0_node); - g.AddControlEdge(identity0_node, identity1_node); - // Case 1b: control edges between outside compilation and another XLA + // Case 1a: control edges between outside compilation and another XLA // computation. g.AddControlEdge(identity0_node, identity3_node); g.AddControlEdge(identity1_node, identity4_node); - // Case 1c: control edges between different outside compilations. 
+ // Case 1b: control edges between different outside compilations. g.AddControlEdge(identity0_node, identity4_node); - // Case 1d: control edges between outside compilation and host computation. + // Case 1c: control edges between outside compilation and host computation. g.AddControlEdge(const0_node, identity0_node); g.AddControlEdge(identity0_node, identity2_node); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - // Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the - // outside compilation node. - EXPECT_TRUE(HasNodeAttr(identity0_node->def(), - kXlaConnectedFromXlaComputationAttrName)); - EXPECT_TRUE(HasNodeAttr(identity0_node->def(), - kXlaConnectedToXlaComputationAttrName)); - // Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name" + // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" // to the outside compilation node. std::vector attr; TF_CHECK_OK(GetNodeAttr(identity0_node->def(), @@ -140,13 +131,13 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { kXlaConnectedFromOtherXlaComputationAttrName, &attr)); EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr[0], "0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. + // Case 1b: add attr "_xla_control_deps = src node name" to dst node. attr.clear(); TF_CHECK_OK(GetNodeAttr(identity4_node->def(), kXlaControlDependenciesAttrName, &attr)); EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr[0], "identity0"); - // Case 1d: add attr "_xla_control_deps = src node name" to dst node. + // Case 1c: add attr "_xla_control_deps = src node name" to dst node. 
attr.clear(); TF_CHECK_OK(GetNodeAttr(identity0_node->def(), kXlaControlDependenciesAttrName, &attr)); @@ -162,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { TEST(PreprocessForEncapsulationTest, DataEdges) { // Build the graph: // "const_0" and "const_1" in host computation + // "identityn0" = ("const_0", "const_1") in host computation 0 // "add0" = "const_0" + "const_1" in XLA computation 0 // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 // "identity0" = "add1" in XLA computation 0 // "add2" = "add1" + "identity0" in host computation // "add3" = "add1" + "add2" in XLA computation 1 - // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1 + // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 + // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & + // outside compilation 0 + // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & + // outside compilation 0 // "identity1" = "add4" in XLA computation 1 // "identity2" = "identity1" in host computation tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); + auto identityn0 = + ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); + Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); + auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), + {identityn0[0], identityn0[1]}); Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); Output identity2 = 
ops::Identity(s.WithOpName("identity2"), add4); Graph g(OpRegistry::Global()); @@ -189,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], *identity0_node = node_index["identity0"], *add3_node = node_index["add3"], *add4_node = node_index["add4"], + *add5_node = node_index["add5"], + *identityn1_node = node_index["identityn_1"], *identity1_node = node_index["identity1"]; add0_node->AddAttr("_xla", "0"); add1_node->AddAttr("_xla", "0"); @@ -197,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { add3_node->AddAttr("_xla", "1"); add4_node->AddAttr("_xla", "1"); add4_node->AddAttr("_oc", "0"); + add5_node->AddAttr("_xla", "1"); + add5_node->AddAttr("_oc", "0"); + identityn1_node->AddAttr("_xla", "1"); + identityn1_node->AddAttr("_oc", "0"); identity1_node->AddAttr("_xla", "1"); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); @@ -214,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { EXPECT_NE(bridge_identity0_add4, nullptr); // Step 3: add placeholder for edges between host computation and outside // compilation. 
- EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder"); - Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"]; + EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); + Node *add1_oc_to_host_placeholder = + node_index["add1_oc_to_host_placeholder_0"]; TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, &str)); EXPECT_EQ(str, "add1"); @@ -226,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) { add4_node = node_index["add4"]; ASSERT_NE(add4_node, nullptr); EXPECT_EQ(add4_node->def().input(0), - "bridge_identity0_add4_host_to_oc_placeholder"); + "bridge_identity0_add4_host_to_oc_placeholder_0"); Node *identity0_host_to_oc_placeholder = - node_index["bridge_identity0_add4_host_to_oc_placeholder"]; + node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &str)); EXPECT_EQ(str, "bridge_identity0_add4"); TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), kHostToOutsideCompilationSrcOutputAttrName, &i)); EXPECT_EQ(i, 0); + + // Check different placeholder nodes are created for different src_output. + Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], + *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; + EXPECT_NE(placeholder0, nullptr); + EXPECT_NE(placeholder1, nullptr); + // Check we only have 2 placeholder nodes created for "identityn_0". 
+ int placeholder_count = 0; + for (Node *n : g.nodes()) { + if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { + string attr; + TF_CHECK_OK(GetNodeAttr( + n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); + if (attr == "identityn_0") { + ++placeholder_count; + } + } + } + EXPECT_EQ(placeholder_count, 2); } TEST(PostprocessForEncapsulationTest, ControlEdges) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 2ce6fa73fc448ca83fa392aa909cb385453eb8b6..d334100aa4a915a87fb05d371e0e3379a7ee05f2 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && e->dst()->type_string() != kXlaClusterOutput) { return errors::InvalidArgument( - "Undeclared output of XLA computation. A common cause of this error " - "is variable initializers that depend on the XLA computation. Edge: ", + "Undeclared output of XLA computation. Some common causes of this " + "error are: 1) variable initializers that depend on the XLA " + "computation; 2) gradient computations that depend on the XLA " + "computation, which can be mitigated by moving gradient computations " + "inside XLA computation. 
Offending edge: ", e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", e->dst_input()); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 8b3587c5087a0651c466f53f3709ba21e75dd273..e3c7e2f89be9b37b51a633dabb099969c181013f 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -366,7 +366,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( // replace this node with compilation result node. // 3) all outside compilation graphs. Status ConstructHostGraph( - const string& xla_cluster_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { host_graph->reset(new Graph(fld)); @@ -476,6 +476,10 @@ Status ConstructHostGraph( host_graph->get(), std::unordered_set{(*host_graph)->sink_node()}); + // Postprocess edges between different outside compilations. + TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( + host_graph->get(), outside_compilation_attr_name)); + if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", @@ -801,6 +805,11 @@ Status ExtractOutsideCompilationForFunction( }, &fbody)); std::unique_ptr fbody_deleter(fbody); + + // Preprocess edges between different outside compilations. They will be + // restored in `ConstructHostGraph()`. + TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( + fbody->graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_before_", func_name), @@ -860,8 +869,9 @@ Status ExtractOutsideCompilationForFunction( // Construct host graph. 
if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR(ConstructHostGraph( - xla_cluster_name, outside_compilation_host_graphs, fld, host_graph)); + TF_RETURN_IF_ERROR( + ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph)); } // Remove the outside compilation graphs from function library. diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index c5bd64f004ef98853955372680277e04c16bdc9e..bff956100da661b679b4557fce53671e6cef88c5 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -290,21 +290,18 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes)); EXPECT_EQ(shapes.size(), 1); EXPECT_EQ(shapes[0].dim_size(), 1); - // Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a - // non-empty value, and "1" should have an empty value. + // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have + // empty values. string shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, - "_outside_compilation_shape_inference_cluster_0"); + EXPECT_EQ(shape_inference_graph, ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); EXPECT_EQ(shape_inference_graph, ""); // Check `shape_inference_graphs`. - EXPECT_EQ(shape_inference_graphs.size(), 1); - EXPECT_EQ(shape_inference_graphs[0], - "_outside_compilation_shape_inference_cluster_0"); + EXPECT_EQ(shape_inference_graphs.size(), 0); // Check `host_graph`: verify we have key placeholder and sequencer. 
Node *key_placeholder = nullptr, *sequencer = nullptr; @@ -333,8 +330,8 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { send_recv_nodes.push_back(n); } } - EXPECT_EQ(num_send_from_host, 2); - EXPECT_EQ(num_recv_at_host, 2); + EXPECT_EQ(num_send_from_host, 1); + EXPECT_EQ(num_recv_at_host, 1); for (Node *n : send_recv_nodes) { Node *input_node; TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node)); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..98e344b3a080aa8aab27cd41564a90427bac151e --- /dev/null +++ b/tensorflow/compiler/jit/flags.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include // NOLINT + +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +BuildXlaOpsPassFlags* build_ops_flags; +DumpGraphFlags* dump_graph_flags; +MarkForCompilationPassFlags* mark_for_compilation_flags; +XlaDeviceFlags* device_flags; +XlaOpsCommonFlags* ops_flags; + +std::vector* flag_list; +std::once_flag flags_init; + +void AppendDumpGraphFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix, + "Path prefix to which graphs dumped during debugging should be " + "written."), + }; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", + &mark_for_compilation_flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. 
Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", + &mark_for_compilation_flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", + &mark_for_compilation_flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", + &mark_for_compilation_flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", + &mark_for_compilation_flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", + &mark_for_compilation_flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AllocateAndParseFlags() { + build_ops_flags = new BuildXlaOpsPassFlags; + build_ops_flags->tf_xla_enable_lazy_compilation = true; + + dump_graph_flags = new DumpGraphFlags; + dump_graph_flags->tf_dump_graph_prefix = "/tmp/"; + + mark_for_compilation_flags = new MarkForCompilationPassFlags; + mark_for_compilation_flags->tf_xla_auto_jit = 0; + mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->tf_xla_max_cluster_size = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_clustering_debug = false; + mark_for_compilation_flags->tf_xla_cpu_global_jit = false; + mark_for_compilation_flags->tf_xla_clustering_fuel = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_fusion_only = false; + + device_flags = new XlaDeviceFlags; + device_flags->tf_xla_compile_on_demand = false; + + ops_flags = new XlaOpsCommonFlags; + ops_flags->tf_xla_always_defer_compilation = false; + + flag_list = new std::vector({ + Flag("tf_xla_enable_lazy_compilation", + 
&build_ops_flags->tf_xla_enable_lazy_compilation, ""), + + Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + + Flag("tf_xla_always_defer_compilation", + &ops_flags->tf_xla_always_defer_compilation, ""), + }); + AppendDumpGraphFlagsInternal(flag_list); + AppendMarkForCompilationPassFlagsInternal(flag_list); + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); +} + +} // namespace + +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *build_ops_flags; +} + +DumpGraphFlags* GetDumpGraphFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return dump_graph_flags; +} + +MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return mark_for_compilation_flags; +} + +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return device_flags; +} + +const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *ops_flags; +} + +void AppendMarkForCompilationPassFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendMarkForCompilationPassFlagsInternal(flag_list); +} + +void AppendDumpGraphFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendDumpGraphFlagsInternal(flag_list); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/flags.h similarity index 57% rename from tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h rename to tensorflow/compiler/jit/flags.h index 79b47357a179d2d9e0d1b6bf9c9f814288bcd5e1..5ddea588eef5270880d91623dc05893da265960a 100644 --- 
a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -13,10 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. +#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_FLAGS_H_ #include @@ -24,15 +22,8 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags( - std::vector* flag_list); -// The values of flags associated with the XLA bridge's -// mark_for_compilation_pass module. +// Flags associated with the XLA bridge's mark_for_compilation_pass module. struct MarkForCompilationPassFlags { int32 tf_xla_auto_jit; // Control compilation of operators into XLA // computations on CPU and GPU devices. 0 = use @@ -57,12 +48,56 @@ struct MarkForCompilationPassFlags { // only using XLA. }; -// Return a pointer to the MarkForCompilationPassFlags struct; +// Flags associated with the XLA bridge's xla_device module. +struct XlaDeviceFlags { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +}; + +// Flags common to the _Xla* ops and their kernels. 
+struct XlaOpsCommonFlags { + // If true, _XlaCompile always refuses to compile the cluster, which means the + // XLA clusters always run in the TF executor. Defaults to false. + bool tf_xla_always_defer_compilation; +}; + +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to true. + bool tf_xla_enable_lazy_compilation; +}; + +// Flags for the XLA bridge's dump_graph module. +struct DumpGraphFlags { + // Path prefix to which graphs dumped during debugging should be written. + string tf_dump_graph_prefix; +}; + +// Return a pointer to the DumpGraphFlags struct; // repeated calls return the same pointer. // This should be called only after Flags::Parse() has returned. + +// Getters for flags structs defined above. The first call to any of these +// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer +// always return the same pointer. MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); +XlaDeviceFlags* GetXlaDeviceFlags(); +const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); +DumpGraphFlags* GetDumpGraphFlags(); + +// Appends the flag definitions associated with +// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. +// +// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. 
+void AppendMarkForCompilationPassFlags( + std::vector* flag_list); +void AppendDumpGraphFlags(std::vector* flag_list); -} // namespace legacy_flags } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index d984ca15cb722821b2a466a90387a29cbc1d1097..ce53f70b79d97ab087fefe542920b33f883632a2 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -208,8 +208,12 @@ Status ComputeSliceSize(const Scope& host_scope, DCHECK_EQ(slice_size.back().type(), DT_INT64); } - *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, - ops::Const(host_scope.WithOpName("concat_axis"), 0)); + // Trivial ConcatV2 nodes (with exactly one input) are disallowed. + *size = + slice_size.size() == 1 + ? 
slice_size[0] + : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + ops::Const(host_scope.WithOpName("concat_axis"), 0)); return Status::OK(); } @@ -242,6 +246,9 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .WithOpName("static_shaped_slice"), slice_inputs_int64.input, slice_inputs_int64.begin, slice_size) .node(); + + TF_RETURN_IF_ERROR(main_scope.status()); + std::vector compile_time_const_inputs; compile_time_const_inputs.push_back("size"); (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, @@ -284,49 +291,45 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// If `n` is a slice we can rewrite to have a static shape (i.e. have the output -// shape only depend on the "size" input) then returns the a SliceInputs -// representing the inputs to `n`. Otherwise returns nullopt. -StatusOrOptional IsRewritableSlice(Node* n) { +// Return true if `n` is a slice we can rewrite to have a static shape +// (i.e. have the output shape only depend on the "size" input). +xla::StatusOr IsRewritableSlice(Node* n) { if (n->type_string() != "Slice") { - return {absl::nullopt}; + return false; } if (!GetXlaClusterForNode(*n).has_value()) { // There is no need to change slice ops outside XLA clusters. - return {absl::nullopt}; + return false; } TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, GetSliceInputs(n)); if (!slice_inputs.has_value()) { - return {absl::nullopt}; + return false; } // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. 
- bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); - if (!slice_is_ok) { - return {absl::nullopt}; - } - - return slice_inputs; + return absl::c_all_of(slice_inputs->size_as_vector, + [](int64 size_i) { return size_i >= -1; }); } Status FindAndRewriteSlices(Graph* g, bool* changed) { - std::vector> slices_to_rewrite; + std::vector slices_to_rewrite; for (Node* n : g->nodes()) { - TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, - IsRewritableSlice(n)); - if (slice_inputs.has_value()) { - slices_to_rewrite.push_back({n, std::move(*slice_inputs)}); + TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + if (is_rewritable) { + slices_to_rewrite.push_back(n); } } - for (const auto& pair : slices_to_rewrite) { - TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second, - *GetXlaClusterForNode(*pair.first))); + for (Node* n : slices_to_rewrite) { + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + TF_RET_CHECK(slice_inputs.has_value()); + TF_RETURN_IF_ERROR( + RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n))); } if (!slices_to_rewrite.empty()) { @@ -342,8 +345,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) { Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_clustering_debug) { dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", **options.graph, options.flib_def); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 0f6f612e967035f6af3e4aff2a499d5cedd018af..a2f1b831ad7605237e23c15cc43b337e06265553 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ 
b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -27,6 +27,7 @@ limitations under the License. namespace tensorflow { namespace { +using ::testing::_; using testing::matchers::AssignedDevice; using testing::matchers::Attr; using testing::matchers::Const; @@ -142,6 +143,26 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) { EXPECT_THAT(static_shaped_slice, m_dynamic_slice); } +TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + EXPECT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2"))))); +} + TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { Scope root = Scope::NewRootScope() .ExitOnError() @@ -166,18 +187,18 @@ TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { CtrlDeps(NodeWith(Op("Placeholder"), Name("control"))))); } +int64 ToInt64(int v) { return static_cast(v); } + TEST(SliceToDynamicSliceRewriteTest, Int64Indices) { Scope root = Scope::NewRootScope() .ExitOnError() .WithAssignedDevice(kDeviceName) .WithXlaCluster("cluster_0"); - auto to_int64 = [](int v) { return static_cast(v); }; - Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); Output size = - ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)}); + ops::Const(root.WithOpName("size"), 
{ToInt64(-1), ToInt64(500)}); Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); std::unique_ptr result; @@ -252,13 +273,35 @@ TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) { Attr(kXlaCompileTimeConstantInputsAttr))))); } +TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + Output size = ops::Const(root.WithOpName("size"), {}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(static_shaped_slice, + NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr), + Inputs(_, _, Out(NodeWith(Name(size.node()->name())))))); +} + TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { Scope root = Scope::NewRootScope() .ExitOnError() .WithAssignedDevice(kDeviceName) .WithXlaCluster("cluster_0"); - auto to_int64 = [](int v) { return static_cast(v); }; + auto ToInt64 = [](int v) { return static_cast(v); }; Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); @@ -271,7 +314,7 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder); Output size = - ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}}); + ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}}); TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2)); std::unique_ptr result; @@ -281,5 +324,82 @@ 
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { Not(Contains(NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr))))); } + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a); + + Output size_b = ops::Const(root.WithOpName("size_a"), {-1, 200}); + Output slice_with_slice_input = ops::Slice( + root.WithOpName("slice_with_slice_input"), slice, begin, size_b); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_input/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(Out(NodeWith( + Op("Slice"), + Name("slice/static_shaped_slice/static_shaped_slice"))), + _, _))); +} + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input_float = + ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT); + Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64); + + Output begin_begin = + ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32); + Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1}); + Output begin = + ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size); + + 
Output size = + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)}); + Output slice_with_slice_begin = ops::Slice( + root.WithOpName("slice_with_slice_begin"), input_float, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_begin/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(_, + Out(NodeWith( + Op("Slice"), + Name("begin/static_shaped_slice/static_shaped_slice"))), + _))); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 830db9ebdd92608c375ad778eced833e26729325..0583774714c6db7a2fa515fc8a0d304e1898db97 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -12,10 +12,10 @@ cc_library( hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", - "//tensorflow/compiler/jit/legacy_flags:xla_ops_common_flags", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 055de7afcc538a1a1183f3687d998a5b2211c887..ad71df5a694a5f8da94675049df1062a7edb6253 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -418,7 +418,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { cannot_compile_cluster = cannot_compile_cluster_; } - if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || cannot_compile_cluster) { executable = nullptr; } else { diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD deleted file mode 100644 index 5fa6c85f06f863f5d18bc4939ffa0ae820d222bd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -# Legacy command line flags for the XLA bridge libraries. - -# Please do not add more flags to this package. - -# The XLA bridge libraries were written in an environment that allowed -# command-line flags to be scattered freely throughout the libraries. This -# model, while initially convenient, leads to a proliferation in unused command -# line flags in tests and binaries, and serious problems in servers, where one -# might wish parameters to be different in independent RPC calls to the same -# routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. 
- -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -cc_library( - name = "mark_for_compilation_pass_flags", - srcs = ["mark_for_compilation_pass_flags.cc"], - hdrs = ["mark_for_compilation_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_device_flags", - srcs = ["xla_device_flags.cc"], - hdrs = ["xla_device_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "build_xla_ops_pass_flags", - srcs = ["build_xla_ops_pass_flags.cc"], - hdrs = ["build_xla_ops_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_ops_common_flags", - srcs = ["xla_ops_common_flags.cc"], - hdrs = ["xla_ops_common_flags.h"], - deps = - [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc deleted file mode 100644 index 961c17c17eac891261530ef25baaa50f8496c331..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <mutex> // NOLINT - -#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { -namespace { - -BuildXlaOpsPassFlags* flags; -std::vector<Flag>* flag_list; -std::once_flag flags_init; - -void AllocateAndParseFlags() { - flags = new BuildXlaOpsPassFlags; - flags->tf_xla_enable_lazy_compilation = true; - flag_list = new std::vector<Flag>({ - Flag("tf_xla_enable_lazy_compilation", - &flags->tf_xla_enable_lazy_compilation, ""), - }); - xla::ParseFlagsFromEnv(*flag_list); -} - -} // namespace - -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return *flags; -} -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h deleted file mode 100644 index 9aa5cf64d6db56ae36875ca08d2ae88c73604733..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ - -namespace tensorflow { -namespace legacy_flags { - -// Flags for the build_xla_ops pass. -struct BuildXlaOpsPassFlags { - // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. - // Defaults to true. - bool tf_xla_enable_lazy_compilation; -}; - -// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment -// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS -// only the first time this routine is called. -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc deleted file mode 100644 index bad306e0b0a3061ba13dc69c08066c642667a2b9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static MarkForCompilationPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new MarkForCompilationPassFlags; - flags->tf_xla_auto_jit = 0; - flags->tf_xla_min_cluster_size = 2; - flags->tf_xla_max_cluster_size = std::numeric_limits::max(); - flags->tf_xla_clustering_debug = false; - flags->tf_xla_cpu_global_jit = false; - flags->tf_xla_clustering_fuel = std::numeric_limits::max(); - flags->tf_xla_fusion_only = false; - flag_list = new std::vector( - {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. 
" - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions."), - Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, - "Places an artificial limit on the number of ops marked as " - "eligible for clustering."), - Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, - "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}); - xla::ParseFlagsFromEnv(*flag_list); - - if (VLOG_IS_ON(1)) { - VLOG(1) << "Parsed MarkForCompilationPassFlags:"; - VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit; - VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size; - VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size; - VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug; - VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; - VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel; - VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only; - } -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the MarkForCompilationPassFlags struct; -// repeated calls return the same pointer. 
-// This should be called only after Flags::Parse() has returned. -MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc deleted file mode 100644 index 76b80d3034c8a13a1ddf1afe548d5c3d9c7b2cec..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's xla_device module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static XlaDeviceFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new XlaDeviceFlags; - flags->tf_xla_compile_on_demand = false; - flag_list = new std::vector({ - Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, - "Switch a device into 'on-demand' mode, where instead of " - "autoclustering ops are compiled one by one just-in-time."), - }); - xla::ParseFlagsFromEnv(*flag_list); -} - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h deleted file mode 100644 index 27b22121ac1e089bd5d5a494e1e3fb60b05bc76d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ - -// Legacy flags for the XLA bridge's xla_device module. 
- -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// The values of flags associated with the XLA bridge's -// xla_device module. -typedef struct { - // Switch the CPU device into "on-demand" mode, where instead of - // autoclustering ops are compiled one by one just-in-time. - // Enabling this mode by a legacy flag is a temporary mechanism. When this - // feature is battle-tested, we will switch this to be a session option. - bool tf_xla_compile_on_demand; -} XlaDeviceFlags; - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc deleted file mode 100644 index 1443d48a734c0a44c1cd91d8d1218bdbed7f765c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include <mutex> // NOLINT -#include <vector> - -#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -XlaOpsCommonFlags* flags; -std::vector<Flag>* flag_list; -std::once_flag flags_init; - -void AllocateAndParseFlags() { - flags = new XlaOpsCommonFlags; - flags->tf_xla_always_defer_compilation = false; - flag_list = new std::vector<Flag>({ - Flag("tf_xla_always_defer_compilation", - &flags->tf_xla_always_defer_compilation, ""), - }); - xla::ParseFlagsFromEnv(*flag_list); - - if (VLOG_IS_ON(1)) { - VLOG(1) << "Parsed XlaOpsCommonFlags:"; - VLOG(1) << " tf_xla_always_defer_compilation = " - << flags->tf_xla_always_defer_compilation; - } -} - -const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return *flags; -} -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h deleted file mode 100644 index 7c5c1818ef2d1dcf38c324a2c926db9c4bfa8ef5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ - -namespace tensorflow { -namespace legacy_flags { - -// Flags common to the _Xla* ops and their kernels. -struct XlaOpsCommonFlags { - // If true, _XlaCompile always refuses to compile the cluster, which means the - // XLA clusters always run in the TF executor. Defaults to false. - bool tf_xla_always_defer_compilation; -}; - -// Parses the flags in XlaOpsCommonFlags from the TF_XLA_FLAGS environment -// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS -// only the first time this routine is called. -const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 70033cae0afacb6a25598ee1abf2aeb2721e7496..6618e3a58ab7b6374ed775cd6e4e18a6a4975588 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -72,6 +72,11 @@ struct OperationFilter { // to resort to a dummy implementation. 
Currently Assert and CheckNumerics ops // have dummy XLA implementations. bool allow_dummy_ops; + + // Whether ops that produce or consume DT_VARIANT values are allowed. We + // don't auto-cluster these ops because we don't yet support live-in or + // live-out DT_VARIANT values. + bool allow_ops_producing_or_consuming_variant; }; bool IsDummyImplOp(absl::string_view op_name) { @@ -81,7 +86,13 @@ bool IsDummyImplOp(absl::string_view op_name) { bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || - op_name == "TruncatedNormal"; + op_name == "TruncatedNormal" || op_name == "Multinomial"; +} + +bool OpProducesOrConsumesVariant(const Node& node) { + auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; + return absl::c_any_of(node.input_types(), is_variant) || + absl::c_any_of(node.output_types(), is_variant); } bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { @@ -246,6 +257,10 @@ bool IsCompilableCall(const NodeDef& call_def, if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { return false; } + if (!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, lib_runtime)) { @@ -427,8 +442,7 @@ Status FindCompilationCandidates( BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, &compile_time_const_nodes)); - int64& fuel = - legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; // Iterate over nodes in sorted order so that compiler fuel is deterministic. 
// We can't simply pass op_nodes().begin() and op_nodes().end to the @@ -471,16 +485,15 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)); DeviceType jit_device_type(registration->compilation_device_name); + bool always_auto_cluster = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; + OperationFilter op_filter; op_filter.allow_resource_ops = registration->compile_resource_ops; - op_filter.allow_stateful_rng_ops = - (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways); - op_filter.allow_control_trigger = - (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways); - op_filter.allow_dummy_ops = (registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kAlways); + op_filter.allow_stateful_rng_ops = always_auto_cluster; + op_filter.allow_control_trigger = always_auto_cluster; + op_filter.allow_dummy_ops = always_auto_cluster; + op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster; if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, @@ -504,6 +517,12 @@ Status FindCompilationCandidates( << node->type_string() << ")"; continue; } + if (!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + VLOG(2) << "Rejecting " << node->name() + << ": produces or consumes DT_VARIANT"; + continue; + } if (!op_filter.allow_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { @@ -607,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( // To set compilation to be on by default, change the following line. 
global_jit_level = OptimizerOptions::OFF; } - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_auto_jit == -1 || (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides @@ -641,6 +659,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { op_filter.allow_stateful_rng_ops = true; op_filter.allow_control_trigger = true; op_filter.allow_dummy_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } @@ -651,8 +670,7 @@ Status MarkForCompilationPass::Run( // device ahead of time. OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool fusion_only = flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; @@ -953,8 +971,7 @@ Status MarkForCompilationPass::RunImpl( OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. 
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 24d78c077268f83cebbdafddc1a658ae8dc6b8d8..bf2c5508ea9e987e80093f4c2e15d3ff5191126f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -1147,5 +1148,80 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) { EXPECT_EQ(clusters["test/check"], ""); } +TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_reserve"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output dummy_input = + ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64); + Output variant_input = + ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT); + + // Create one more 
node so that we don't avoid creating a cluster solely + // because it would be trivial. + Output dummy_cast = + ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32); + + Output tensor_list_element_shape = ops::TensorListElementShape( + root.WithOpName("test/tensor_list_element_shape"), variant_input, + DT_INT32); + + root.graph()->AddControlEdge(dummy_cast.node(), + tensor_list_element_shape.node()); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); +} + +TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(xla_cpu_device); + } + } + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/tensor_list_reserve"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 
d56d0f8ccfcdab40003be38059228cb255921b64..64a3301745790132fe3149bf8fb52d6c45ecc3c1 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -34,15 +34,9 @@ namespace tensorflow { // // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to // make this more direct, but probably not worth it solely for this test. - std::vector<Device*> devices; + std::vector<std::unique_ptr<Device>> devices; TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); - auto delete_devices = gtl::MakeCleanup([&] { - for (Device* d : devices) { - delete d; - } - }); - GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.session_options = session_options; diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index f72224545b25bc7100e0b6788e6fbf0a7ca63dad..64409d9334751e0edfce9091a4e5697dd2c712c5 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -18,3 +18,9 @@ tf_gen_op_wrapper_py( out = "xla_ops.py", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) + +py_library( + name = "xla_ops_grad", + srcs = ["xla_ops_grad.py"], + deps = ["//tensorflow/python:framework_ops"], +) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/compiler/jit/ops/xla_ops_grad.py similarity index 62% rename from tensorflow/contrib/estimator/python/estimator/dnn.py rename to tensorflow/compiler/jit/ops/xla_ops_grad.py index 10f657df8de64cc96f0cf04f434a77df66629dca..2d31d8dc714307a48932d061fb1af643940a0872 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/compiler/jit/ops/xla_ops_grad.py @@ -1,3 +1,4 @@ +"""Gradients for XLA ops.""" # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,21 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""dnn python module. - -Importing from tensorflow.python.estimator is unsupported -and will soon break! -""" -# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow_estimator.contrib.estimator.python.estimator import dnn +from tensorflow.python.framework import ops -# Include attrs that start with single underscore. -_HAS_DYNAMIC_ATTRIBUTES = True -dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')] -from tensorflow_estimator.contrib.estimator.python.estimator.dnn import * +@ops.RegisterGradient("XlaClusterOutput") +def _XlaClusterOutputGrad(_, grad): + del grad # unused + raise RuntimeError("Gradient computation of graph in xla.compile() is " + "prohibited because it can cause performance degradation." + "Please move gradient computation inside xla.compile().") diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 36b345ecbff8d5f6ba3c241b9e164f677236c20d..42ea3926e16ae791dbe1bede3b8742383db7667c 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -26,6 +26,10 @@ limitations under the License. 
namespace tensorflow { namespace { + +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } + +namespace reduce_device_to_host_copies { Status FindNodesToDecluster(const Graph& graph, absl::flat_hash_set* result, absl::Span post_order) { @@ -140,8 +144,6 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { return Status::OK(); } -bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } - // Clones nodes to outside their cluster to avoid device-to-host copies. For // instance, converts this: // @@ -168,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } // where the ===> arrow has a hostmem source and destination and would entail a // device to host copy if the source and destination were not in the same XLA // cluster. -Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been @@ -206,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { return Status::OK(); } +} // namespace reduce_device_to_host_copies +namespace reduce_recompilation { bool IsIntraClusterEdge(const Edge& edge) { absl::optional src_cluster_name = GetXlaClusterForNode(*edge.src()); @@ -269,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* must_compile) { // regress performance in any significant manner. We will have to revisit this // algorith with a more complex cost model if this assumption turns out to be // incorrect. 
-Status DeclusterNodesToReduceRecompilations(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { std::vector compile_time_const_nodes(graph->num_node_ids()); TF_RETURN_IF_ERROR(BackwardsConstAnalysis( *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); @@ -322,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) { return Status::OK(); } - +} // namespace reduce_recompilation } // namespace Status PartiallyDeclusterPass::Run( @@ -334,8 +338,9 @@ Status PartiallyDeclusterPass::Run( Graph* graph = options.graph->get(); - TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); - TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + TF_RETURN_IF_ERROR( + reduce_device_to_host_copies::PartiallyDeclusterGraph(graph)); + TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 1fc5da5071f7aa6f6dd6636aacd60e33c12431a6..38a54cc5efae35ad77b6dc8039c653e920cfc071 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { TF_ASSERT_OK(s.ToGraph(graph.get())); // This is needed to register the XLA_GPU device. 
- std::vector devices; + std::vector> devices; TF_ASSERT_OK(DeviceFactory::AddDevices( SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); @@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { TF_ASSERT_OK(PartiallyDecluster(&graph)); EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); - - for (Device* d : devices) { - delete d; - } } TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h deleted file mode 100644 index 7c8c04152d2f3a0fd46711df24756b7e68b967ea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ -#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ - -#include -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -// A thread-safe, first-in-first-out queue. -template -class ProducerConsumerQueue { - public: - ProducerConsumerQueue() - : capacity_(std::numeric_limits::max()) {} - ~ProducerConsumerQueue() = default; - - // Wait until the queue is non-full, then append a copy of v. 
- void Put(const T &v); - - // Wait until the queue is non-empty, then remove and return the head value. - T Get(); - - // If the queue is non-empty, remove the head value, placing it in *pv, and - // return true; otherwise return false. - bool TryGet(T *pv); - - // Set the capacity of the queue; the queue is full whenever count() >= - // capacity(). The initial value is the maximum size_t. Requires size > 0. - void set_capacity(std::size_t size); - - // Return the capacity of the queue. - std::size_t capacity() const; - - // Return the number of elements in the queue. - std::size_t count() const; - - // Implementation details follow. Clients should ignore. - private: - mutable tensorflow::mutex mu_; // protects all fields below - tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); - tensorflow::condition_variable non_full_ GUARDED_BY(mu_); - std::size_t capacity_ GUARDED_BY(mu_); - std::deque queue_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); -}; - -// ------------------------------------------------------ -// Implementation details follow. Clients should ignore. - -// Wait until the queue is non-full, then append a copy of v. -template -void ProducerConsumerQueue::Put(const T &v) { - mutex_lock lock(mu_); - while (queue_.size() >= capacity_) { - non_full_.wait(lock); - } - queue_.push_back(v); - non_empty_.notify_one(); -} - -// Wait until the queue is non-empty, then remove and return the head value. -template -T ProducerConsumerQueue::Get() { - mutex_lock lock(mu_); - while (queue_.empty()) { - non_empty_.wait(lock); - } - non_full_.notify_one(); - T result_value = queue_.front(); - queue_.pop_front(); - return result_value; -} - -// If the queue is non-empty, remove the head value, placing it in *pv, and -// return true; otherwise return false. 
-template -bool ProducerConsumerQueue::TryGet(T *pv) { - mutex_lock lock(mu_); - bool got_element = !queue_.empty(); - if (got_element) { - non_full_.notify_one(); - *pv = queue_.front(); - queue_.pop_front(); - } - return got_element; -} - -// Set the capacity of the queue; the queue is full whenever count() >= -// capacity(). The initial value is the maximum size_t. Requires size > 0. -template -void ProducerConsumerQueue::set_capacity(std::size_t size) { - mutex_lock lock(mu_); - CHECK_NE(size, 0); - capacity_ = size; - non_full_.notify_all(); -} - -// Return the capacity of the queue. -template -std::size_t ProducerConsumerQueue::capacity() const { - mutex_lock lock(mu_); - std::size_t max_elements = capacity_; - return max_elements; -} - -// Return the number of elements in the queue. -template -std::size_t ProducerConsumerQueue::count() const { - mutex_lock lock(mu_); - std::size_t num_elements = queue_.size(); - return num_elements; -} -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc deleted file mode 100644 index f61260c6e52756ee039829afdc7452f5f760c221..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/producer_consumer_queue.h" - -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -typedef ProducerConsumerQueue IntQueue; - -// Insert integers between low inclusive and high exclusive into q. -void PushRange(IntQueue *q, int low, int high) { - while (low != high) { - q->Put(low); - VLOG(2) << "Pushing " << low; - ++low; - } -} - -// Push the numbers between 0 and 999 inclusive from several threads in the -// pool. -void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { - VLOG(1) << "Adding 20-36"; - pool->Schedule([queue] { PushRange(queue, 20, 36); }); - VLOG(1) << "Adding 7-20"; - pool->Schedule([queue] { PushRange(queue, 7, 20); }); - VLOG(1) << "Adding 36-501"; - pool->Schedule([queue] { PushRange(queue, 36, 501); }); - VLOG(1) << "Adding 501-1000"; - pool->Schedule([queue] { PushRange(queue, 501, 1000); }); - VLOG(1) << "Adding 0-5"; - pool->Schedule([queue] { PushRange(queue, 0, 5); }); - VLOG(1) << "Adding 5-7"; - pool->Schedule([queue] { PushRange(queue, 5, 7); }); -} - -// Pop elements from queue using Get(). Make sure that exactly elements -// were present and their values are all integers between 0 and high-1 -// inclusive. -void GetRange(IntQueue *queue, int high) { - VLOG(1) << "Testing Wait"; - std::vector results; - for (int i = 0; i != high; ++i) { - int r = queue->Get(); - VLOG(2) << "Waited and got " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK(results[i] == i); - } -} - -// Pop elements from queue using TryGet(). Make sure that exactly -// elements were present and their values are all integers between 0 and high-1 -// inclusive. 
-void TryGetRange(IntQueue *queue, int high) { - std::vector results; - // Give up if we don't get all the elements back from the queue - // in 10 seconds. - int timeout = 10; - int r; - for (int i = 0; i != high; ++i) { - while (!queue->TryGet(&r)) { - if (!timeout--) { - LOG(FATAL) << "Can't find all elements in the queue"; - } - VLOG(1) << "Sleeping for a second..."; - sleep(1); - } - VLOG(2) << "Popped " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - CHECK(!queue->TryGet(&r)); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK_EQ(i, results[i]); - } -} - -const int kNumThreads = 15; - -TEST(ProducerConsumerQueue, GetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - GetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, TryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - TryGetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, ParallelGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { GetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -TEST(ProducerConsumerQueue, ParallelTryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 1fe612d43d10030675cf307b109e4dcc89cb2d79..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -142,11 +142,22 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(ctx->allocate_temp( device_tensor.dtype(), 
device_tensor.shape(), &host_tensor, attrs)); Notification n; + Status status; ctx->op_device_context()->CopyDeviceTensorToCPU( &device_tensor, "ConstantArgument", reinterpret_cast(ctx->device()), &host_tensor, - [&](Status status) { n.Notify(); }); + [&](Status s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Copying tensor of shape " + << device_tensor.shape().DebugString() << " from " + << ctx->device()->name() << " to CPU failed with " + << status.ToString(); + return status; + } constant_arguments[i] = host_tensor; } } @@ -189,6 +200,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_arguments, variable_args, ctx, &args)); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 116e0756036e722c13f27579aa0e0876d2e846a7..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,8 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend.
#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -31,13 +31,13 @@ namespace tensorflow { class XlaCpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; -Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, - const string& name_prefix, - std::vector* devices) { - legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); +Status XlaCpuDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { + XlaDeviceFlags* flags = GetXlaDeviceFlags(); bool compile_on_demand = flags->tf_xla_compile_on_demand; XlaOpRegistry::DeviceRegistration registration; @@ -64,7 +64,18 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.use_multiple_streams = false; auto device = absl::make_unique(session_options, options); - devices->push_back(device.release()); + + // Setting GpuDeviceInfo because eager runtime relies on the device + // context in tensorflow_gpu_device_info(). Also, + // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test. + // We need XlaCpuDevice to be treated not as CPU because it allocates + // XlaTensors, not regular Tensors. 
+ Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_CPU_XLA_JIT); + return status; + } + devices->push_back(std::move(device)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 5c1b55cb57f58387086ab9eaf924d0beffb43e18..4201ff91a89b1bee370e6a43337c51abe3bf974a 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -218,6 +218,9 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } if (device_context_) { device_context_->Unref(); } @@ -384,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; + tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true); std::shared_ptr stream; { mutex_lock lock(mu_); @@ -391,13 +395,46 @@ Status XlaDevice::Sync() { } if (!stream) return Status::OK(); - if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + Status status = stream->BlockHostUntilDone(); + { + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + } + TF_RETURN_IF_ERROR(status); + if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); } VLOG(1) << "XlaDevice::Sync completed"; return Status::OK(); } +void XlaDevice::Sync(const DoneCallback& done) { + VLOG(1) << "XlaDevice::Sync (asynchronous)"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) { + done(Status::OK()); + return; + } + + stream->ThenEnqueueOnBackgroundThread( + [this, stream, done](se::StreamExecutor*) { +
tracing::ScopedActivity activity("XlaDevice::Sync::Callback", + /*is_expensive=*/true); + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + done(stream->ok() ? Status::OK() + : errors::Internal("XlaDevice::Sync() failed.")); + }); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -441,6 +478,59 @@ bool XlaDevice::RequiresSyncOnCompletion() const { return sync_on_completion_; } +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice* device) + : device_(device) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { + if (device_) { + mutex_lock lock(device_->mu_); + --device_->outstanding_asynchronous_operations_; + device_->outstanding_asynchronous_operations_cv_.notify_all(); + } +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + const XlaDevice::AsynchronousOperationHandle& other) + : device_(other.device_) { + // `other` may be a moved-from handle that no longer references a device. + if (device_) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; + } +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice::AsynchronousOperationHandle&& other) + : device_(other.device_) { + other.device_ = nullptr; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(const XlaDevice::AsynchronousOperationHandle& other) { + // Copy first (registering with other's device), then move-assign so that the + // operation previously tracked by this handle is released exactly once. + return *this = AsynchronousOperationHandle(other); +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(XlaDevice::AsynchronousOperationHandle&& other) { + if (this != &other) { + // Release the operation this handle was tracking before adopting other's. + if (device_) { + mutex_lock lock(device_->mu_); + --device_->outstanding_asynchronous_operations_; + device_->outstanding_asynchronous_operations_cv_.notify_all(); + } + device_ = other.device_; + other.device_ = nullptr; + } + return *this; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 49f53b477ef5508a23812453cb61e29a8d8b9379..c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -135,6 +135,7 @@ class XlaDevice : public LocalDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override; + void Sync(const DoneCallback& done) override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -164,7 +165,30 @@ class XlaDevice : public LocalDevice { bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + // A simple RAII handle. On construction the device's + // outstanding_asynchronous_operations_ field is incremented; on destruction + // it is decremented. + class AsynchronousOperationHandle { + public: + AsynchronousOperationHandle(XlaDevice* device); + ~AsynchronousOperationHandle(); + AsynchronousOperationHandle(const AsynchronousOperationHandle& other); + AsynchronousOperationHandle(AsynchronousOperationHandle&& other); + AsynchronousOperationHandle& operator=( + const AsynchronousOperationHandle& other); + AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); + + private: + XlaDevice* device_ = nullptr; + }; + + AsynchronousOperationHandle CreateAsynchronousOperationHandle() { + return AsynchronousOperationHandle(this); + } + private: + friend class AsynchronousOperationHandle; + xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -227,6 +251,11 @@ class XlaDevice : public LocalDevice { // True if the device requires XlaDevice::Sync to be called on completion // regardless of status. 
bool sync_on_completion_ GUARDED_BY(mu_) = false; + + // Count of outstanding asynchronous operations which must be zero on Sync() + // completion. + int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; + condition_variable outstanding_asynchronous_operations_cv_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index adf0f994b84d9fbf918a5b2478aa7d106853e038..927f983ba9ef23c8509523f42366c0c89c29db9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -203,6 +203,8 @@ class XlaAssignVariableOp : public OpKernel { .HostMemory("output") \ .TypeConstraint("T"), \ ArgOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 441970169581d53e0d8683b98d26712445b170ea..0191315a66f4d331e54fadc9dc6a073a05fd67ef 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,10 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. 
+#include #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -29,12 +32,12 @@ namespace tensorflow { class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; -Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, - const string& name_prefix, - std::vector* devices) { +Status XlaGpuDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.autoclustering_policy = @@ -52,8 +55,35 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); return Status::OK(); } - - for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + string allowed_gpus = + session_options.config.gpu_options().visible_device_list(); + std::set gpu_ids; + int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount(); + if (allowed_gpus.empty()) { + for (int i = 0; i < num_visible_devices; ++i) { + gpu_ids.insert(i); + } + } else { + // For loop below is copied from gpu/gpu_device.cc. It validates + // the visible_device_list and populates gpu_ids set. + const std::vector visible_devices = + absl::StrSplit(allowed_gpus, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, "'. 
visible_device_list = ", allowed_gpus); + } + if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) { + return errors::InvalidArgument( + "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, + "' but visible device count is ", num_visible_devices); + } + gpu_ids.insert(platform_gpu_id); + } + } + for (int i : gpu_ids) { XlaDevice::Options options; options.platform = platform.ValueOrDie(); options.device_name_prefix = name_prefix; @@ -70,7 +100,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, return status; } - devices->push_back(device.release()); + devices->push_back(std::move(device)); } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index e828bae865d630bd40f227943cdabb2d8d95ca48..4007309ed1c57b663dca5bac0df11260bf1327f3 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -33,12 +33,12 @@ constexpr std::array kExecAllTypes = { class XlaInterpreterDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; Status XlaInterpreterDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, - std::vector* devices) { + std::vector>* devices) { static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; @@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices( options.device_ordinal = 0; options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; options.use_multiple_streams = false; - auto device = absl::make_unique(session_options, options); - devices->push_back(device.release()); + devices->push_back(absl::make_unique(session_options, options)); return Status::OK(); } diff 
--git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 437db019a0eabe66417725148d8b121842e90479..554227f09de0ab4d9e07f199b957657f3121ff06 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -199,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer { public: XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, Allocator* allocator) - : expected_size_(expected_size), + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), actual_size_(actual_size), - allocator_(allocator) { - data_ = const_cast(ptr); - } + allocator_(allocator) {} ~XlaTensorBuffer() override { - if (data_) { - allocator_->DeallocateRaw(data_); + if (data()) { + allocator_->DeallocateRaw(data()); } } - void* data() const override { return data_; } size_t size() const override { return expected_size_; } TensorBuffer* root_buffer() override { return this; } @@ -231,7 +229,6 @@ class XlaTensorBuffer : public TensorBuffer { } private: - void* data_; size_t expected_size_; size_t actual_size_; Allocator* allocator_; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6b8e6bba1e1bbfd773141d33721e4d7e30420a11..093b61629cd0b04d5d8488139b8d7262b739f86d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -375,27 +375,6 @@ tf_xla_py_test( ], ) -tf_xla_py_test( - name = "resampler_ops_test", - size = "small", - srcs = ["resampler_ops_test.py"], - disabled_backends = [ - # TODO(b/74459949) Support BatchDot in CPU backend. - "cpu", - "cpu_ondemand", - ], - # TODO(b/112295522): figure out how to make OSS build pass. 
- tags = ["no_oss"], - deps = [ - ":xla_test", - "//tensorflow/contrib/resampler:resampler_ops", - "//tensorflow/contrib/resampler:resampler_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", - ], -) - tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -429,13 +408,6 @@ tf_xla_py_test( name = "eager_test", size = "large", srcs = ["eager_test.py"], - disabled_backends = [ - # TODO(b/78199195) Support XLA CPU devices in eager runtime - "cpu", - "cpu_ondemand", - # TODO(b/78468222) Enable GPU backend - "gpu", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -474,7 +446,6 @@ tf_xla_py_test( "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops", "//tensorflow/python/ops/signal", ], ) diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index 69fb3ec2964a09508e612515b9e291fc14121d68..e9c2d363acab96c0fb968cb7f901ce105ea8703e 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() @@ -63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 # similarly for others. 
self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: @@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAWithL1(self): for dtype in self.float_types: @@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.895489, -1.59555]), var0.eval()) + np.array([-0.895489, -1.59555]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.085339, -0.17989]), var1.eval()) + np.array([-0.085339, -0.17989]), self.evaluate(var1)) def testAdagradDAWithL1_L2(self): 
for dtype in self.float_types: @@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.046907, -0.093659]), var0.eval()) + np.array([-0.046907, -0.093659]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.004275, -0.009023]), var1.eval()) + np.array([-0.004275, -0.009023]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index ab69319c59fb07e7ce56c3c287a50a6290effdfd..e26483303c3934fd51675cb1fbc998b276caf527 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) 
def testTensorLearningRate(self): @@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def testSharing(self): @@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values. - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Mix the first and the second adagrad for 3 steps. ada_update1.run() ada_update2.run() ada_update1.run() # Validate updated params (the same as with only 1 Adagrad). 
self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 058576b3d4b695209952158769162bb24e7ccfce..8bcff9d379d34f8a6bb8b0fdc60b7588c6d80be9 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRate(self): for dtype in self.float_types: @@ -117,23 +118,24 @@ class 
AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testSharing(self): for dtype in self.float_types: @@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase): beta1_power, beta2_power = opt._get_beta_accumulators() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of intertwined Adam1 and Adam2. 
for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) if t % 2 == 0: update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) else: @@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase): var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index 3ed1d41b7121f44dd7470f61180f7a7055369174..961b46375c941bdc3922e460a2f58345086dbceb 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() @@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): for t in range(1, 4): update.run() - self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval(), 
rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(var0), rtol=1e-2) + self.assertAllCloseAccordingToType( + var1_np, self.evaluate(var1), rtol=1e-2) self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) @@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() # Run 3 steps of AdaMax for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) update.run() var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 1bc07ace23ccdc83103abe71ee11b72994c75a6d..a37c97e6d374440aeb860b9d02f2d5dd95c91f62 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + 
self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of AddSign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - var0_np, var0.eval(), half_rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + var0_np, self.evaluate(var0), half_rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 332381c59eed06d5697e58efb1d8fa2b6ef604d2..9a5423c1b2a5df7880453cbb328f6a8174066255 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -218,6 +218,21 @@ class BinaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + # TF doesn't define these for bf16. + if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_math_ops.xdivy, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype)) + + self._testBinary( + gen_math_ops.xlogy, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0], + dtype=dtype)) + def testIntOps(self): for dtype in self.signed_int_types: self._testBinary( diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a57d1dc81ea2c9c188b0a3005904738aa8156bf3..5d5e486f616937601214aa169a4c329ab78932c8 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from 
tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import googletest @@ -56,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. """ - with self.cached_session() as sess, self.test_scope(): + with self.cached_session(), self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) - d = sess.run(op) + d = self.evaluate(op) batch_size, num_classes = logits.shape freqs_mat = [] @@ -79,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
@@ -107,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, output_dtype=output_dtype) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= 0).sum() == 1000) self.assertTrue((y < 20).sum() == 1000) @@ -138,6 +139,57 @@ class CategoricalTest(xla_test.XLATestCase): chi2 = self._chi2(probs, freqs) self.assertLess(chi2, 1e-3) + def testStatelessMultinomialIsInRange(self): + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), + 1000, + seed_t, + output_dtype=output_dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) + + def testDeterminismMultinomial(self): + # Stateless values should be equal iff the seeds are equal (roughly) + num_samples = 10 + with self.cached_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + pure = stateless_random_ops.stateless_multinomial( + logits, num_samples, seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + + def testEmpty(self): + with self.cached_session(): + with self.test_scope(): + x = random_ops.multinomial( + array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32) + y = self.evaluate(x) + 
self.assertEqual(y.shape, (42, 0)) + + def testEmptyStateless(self): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.zeros([42, 40]), + 0, + seed=seed_t, + output_dtype=dtypes.int32) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertEqual(y.shape, (42, 0)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 88bd58b2da6b2892f898ad10f3467d8ce39d6388..ef2d7af69deeebd5f4c4c7225d7027f8f76bf861 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") output = math_ops.add(input1, input2) - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testAddFromCpuMultiple(self): @@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase): with self.test_scope(): output = math_ops.add(input1, input2) for _ in xrange(10): - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testDeadlock(self): diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2d225ad226cac368042b95eae8fc29e6fd8e82e0..2187f57960f80300d631bdc7eb8fe5e9c8dddeea 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase): x2 = constant_op.constant(p2) with self.test_scope(): c = array_ops.concat([x1, x2], 0) - result = c.eval() + result = self.evaluate(c) self.assertAllEqual(result[:2, :], p1) self.assertAllEqual(result[2:, :], p2) @@ -150,7 +150,7 @@ class 
ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 1) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) def testGradientsSimpleAll(self): @@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 0) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 2) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, concat_dim) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(c.eval(), correct) # Check gradients dc = np.random.randn(*c.get_shape().as_list()) - dxs = sess.run(gradients_impl.gradients(c, xs, dc)) + dxs = self.evaluate(gradients_impl.gradients(c, xs, dc)) self.assertAllEqual(dc, np.concatenate(dxs, 
axis=axis)) def testConcatTuple(self): @@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) - self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) + self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t)) def testConcatNoScalars(self): with self.cached_session(): @@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) - ans = sess.run(off) + ans = self.evaluate(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) s2 = constant_op.constant(5, dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = 
constant_op.constant([[]], dtypes.int32) s2 = constant_op.constant([[]], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[[]], [[]], [[]]]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index d59fd0236f4f7da2bbfb3409342c7f70f8f5d1f6..01cc1b6392845be2418c50d55be97487eb290843 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) # We count the number of cells being added at the locations in the output. # At the center, #cells = kernel_depth * kernel_height * kernel_width @@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) for n in xrange(x_shape[0]): for k in xrange(f_shape[3]): @@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="VALID") - value = output.eval() + value = self.evaluate(output) cache_values = np.zeros(y_shape, dtype=np.float32) diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index d1b90f098d7d6574999ba0af44b285f5ad5e4f8d..bf5ea7b1fb6fb3c774c4db20d059f131990d20d3 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -42,7 +42,7 @@ def GetRunMetadataLabels(run_metadata): def InLabels(labels, substr): """Returns true iff one of the labels contains 
substr.""" - return any([substr in x for x in labels]) + return any(substr in x for x in labels) class DenseLayerTest(test.TestCase): @@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase): x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, @@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() test_utils.RunWithWarmup( sess, diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index 50b04daa6b9f4159a3c4bdeecaf900a5b35a833c..e89cf975f5d889091ce92a35165aef55ee5ad4b0 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -58,6 +58,15 @@ class DynamicStitchTest(xla_test.XLATestCase): [idx1, idx2], [val1, val2], expected=np.array([[], [], [], []], np.int32)) + def testEmptyIndex(self): + idx1 = np.array([], dtype=np.int32) + idx2 = np.array([[], []], dtype=np.int32) + val1 = np.ndarray(shape=(0, 9), dtype=np.int32) + val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32) + self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2], + expected=np.ndarray( + shape=(0, 9), dtype=np.int32)) + def testSimple1D(self): val1 = np.array([0, 4, 7], dtype=np.int32) val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32) diff --git a/tensorflow/compiler/tests/eager_test.py 
b/tensorflow/compiler/tests/eager_test.py index 63cee550fde9d9d4314b1541fba191df776a4da2..2af32b537ba53723370faf81aebf308a465718c7 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.cached_session() as sess: + with context.graph_mode(), self.cached_session(): with self.test_scope(): three = constant_op.constant(3) five = constant_op.constant(5) product = three * five - self.assertAllEqual(15, sess.run(product)) + self.assertAllEqual(15, self.evaluate(product)) def testDegenerateSlices(self): with self.test_scope(): diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index e92afd5d6feb42ece233ee521e3a796c4bc3914a..0edd0c35aa2d417a3ed24decbaa0b5d62d35bb62 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -27,8 +27,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import signal -from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.signal import signal from tensorflow.python.platform import googletest BATCH_DIMS = (3, 5) @@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase): def testFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, - spectral_ops.fft) + signal.fft) def testFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, - spectral_ops.fft2d) + signal.fft2d) def testFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), - spectral_ops.fft3d) + signal.fft3d) def testIFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, - spectral_ops.ifft) + signal.ifft) def testIFFT2D(self): 
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, - spectral_ops.ifft2d) + signal.ifft2d) def testIFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), - spectral_ops.ifft3d) + signal.ifft3d) def testRFFT(self): self._VerifyFftMethod( INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), - lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + lambda x: signal.rfft(x, fft_length=[x.shape[-1].value])) def testRFFT2D(self): def _tf_fn(x): - return spectral_ops.rfft2d( + return signal.rfft2d( x, fft_length=[x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod( @@ -153,16 +152,33 @@ class FFTTest(xla_test.XLATestCase): x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) def _tf_fn(x): - return spectral_ops.rfft3d( + return signal.rfft3d( x, fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + def testRFFT3DMismatchedSize(self): + + def _to_expected(x): + return np.fft.rfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _tf_fn(x): + return signal.rfft3d( + x, + fft_length=[ + x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2 + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + def testIRFFT(self): def _tf_fn(x): - return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), @@ -171,7 +187,7 @@ class FFTTest(xla_test.XLATestCase): def testIRFFT2D(self): def _tf_fn(x): - return spectral_ops.irfft2d( + return signal.irfft2d( x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( @@ -195,7 +211,7 @@ class FFTTest(xla_test.XLATestCase): s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) def _tf_fn(x): 
- return spectral_ops.irfft3d( + return signal.irfft3d( x, fft_length=[ x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) @@ -203,6 +219,30 @@ class FFTTest(xla_test.XLATestCase): self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + def testIRFFT3DMismatchedSize(self): + + def _to_input(x): + return np.fft.rfftn( + np.real(x), + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _to_expected(x): + return np.fft.irfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _tf_fn(x): + return signal.irfft3d( + x, + fft_length=[ + x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2 + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 8c7edfd277c992c35a81dd5f261256a86352254e..91d77d2f791834346f43aecb60d116ddbf2faa6e 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase): enqueue_op.run() for i in xrange(len(elems)): - vals = dequeued_t.eval() + vals = self.evaluate(dequeued_t) self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): @@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([], size.get_shape()) enqueue_op.run() - self.assertEqual(1, size.eval()) + self.assertEqual(1, self.evaluate(size)) dequeued_t.op.run() - self.assertEqual(0, size.eval()) + self.assertEqual(0, self.evaluate(size)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 5b197afd655404e4e36a8b3442f8db60cb1d648d..b078053cdbd6d129645734492d34dd25d28ab3ef 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -50,14 +50,14 @@ 
class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivAdagradTest_AdagradPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Adagrad for a few steps for _ in range(steps): adagrad_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_FtrlPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def 
equivGradientDescentTest_GradientDescentPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run GradientDescent for a few steps for _ in range(steps): sgd_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testFtrlwithoutRegularization(self): for dtype in self.float_types: @@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-2.60260963, -4.29698515]), - var0.eval(), + self.evaluate(var0), float_rtol=1e-4, half_rtol=1e-2) self.assertAllCloseAccordingToType( np.array([-0.28432083, -0.56694895]), - var1.eval(), + self.evaluate(var1), float_rtol=1e-5, half_rtol=1e-2) @@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 
2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -167,10 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5, + np.array([-2.55607247, -3.98729396]), + self.evaluate(var0), + 1e-5, + 1e-5, float_rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) + np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5, + 1e-5) def testFtrlWithL1(self): for dtype in self.float_types: @@ -187,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -197,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-7.66718769, -10.91273689]), - var0.eval(), + self.evaluate(var0), rtol=1e-4, bfloat16_rtol=1e-1, bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( - np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) + np.array([-0.93460727, -1.86147261]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -219,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + 
self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -228,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + np.array([-0.24059935, -0.46829352]), + self.evaluate(var0), + rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5) + np.array([-0.02406147, -0.04830509]), + self.evaluate(var1), + rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. @@ -254,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -263,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), + self.evaluate(var0), + rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" @@ -291,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): update1 = opt1.apply_gradients([(grads1, var1)]) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - 
self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -301,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # var0 is experiencing L2 shrinkage so it should be smaller than var1 # in magnitude. - self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all()) accum0 = list(opt0._slots["accum"].values())[0].eval() accum1 = list(opt1._slots["accum"].values())[0].eval() # L2 shrinkage should not change how we update grad accumulator. diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index b1891b918c6584abce9da382088ed0037f5319fb..a61827c2ae44de117abad5b7db5c6bcd78fa171e 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testNestedFunctions(self): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_g = Foo(a, b) - 
result = sess.run(call_g) + result = self.evaluate(call_g) self.assertAllClose(result, expected, rtol=1e-3) def testFunctionMultipleRetvals(self): @@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testCompileTimeConstantsInDefun(self): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6f51ae33a1b0fc8670ddf0cacb03a3b5a9176a91..dbea9849e217519874352b789588a2af62f1c826 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -75,7 +75,7 @@ def RunMetadataLabels(run_metadata): def InLabels(labels, substr): """Returns true iff one of the labels contains substr.""" - return any([substr in x for x in labels]) + return any(substr in x for x in labels) def MetadataHasXlaRunOp(run_metadata): diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 58622114e4f552fb71db9b040a39b57d7da0037c..0210201fa71a6e790e94667073ab4dba542537a5 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.cached_session() as sess: + with self.cached_session(): x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): out_tensor, idx_tensor = 
array_ops.listdiff( x_tensor, y_tensor, out_idx=index_dtype) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor]) self.assertAllEqual(out, tf_out) self.assertAllEqual(idx, tf_idx) self.assertEqual(1, out_tensor.get_shape().ndims) diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index c6ad67993e8bc196a74c9a328df8c9200c92c575..5dddf6ae4e8c8a3d5e9eb7b2c62298df02a0093c 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase): with self.test_scope(): actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, depth_radius, bias, alpha, beta) - expected_val = expected.eval() - actual_val = actual.eval() + expected_val = self.evaluate(expected) + actual_val = self.evaluate(actual) self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 265c0b6d1412de7be3a5bf5e79129cb330ceb162..776ed899e68ddd3893b8bb30b7c8034297aa6515 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -88,8 +88,8 @@ class LSTMTest(test.TestCase): (basename, m_prev_scalar, c_prev_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM step. - sess.run(variables.global_variables_initializer()) - return sess.run([m, c]) + self.evaluate(variables.global_variables_initializer()) + return self.evaluate([m, c]) def testLSTMCell(self): # Run with all-0 weights, no padding. @@ -173,8 +173,8 @@ class LSTMTest(test.TestCase): (basename, m_init_scalar, c_init_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM layer. 
- sess.run(variables.global_variables_initializer()) - return sess.run(out_seq) + self.evaluate(variables.global_variables_initializer()) + return self.evaluate(out_seq) def testLSTMLayer(self): # Run with all-0 weights, no padding. diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index f77521a7c49dba39849869ddceb7c0e885147722..3416f7dbd6bdd264bf79785084f981f5b07cb8a9 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. 
self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) def testNesterovMomentum(self): for dtype in self.float_types: @@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase): var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) var1_np, accum1_np = self._update_nesterov_momentum_numpy( var1_np, accum1_np, 0.9, 0.1, 0.9) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: @@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. 
So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. 
self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 77bb839409f0c323ff6ed2c8d6bd105d3003b398..9671ae0ae973ff82d22744a1feb9b4293d94bbdd 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase): ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 sess.run(variables.variables_initializer([v])) - self.assertEqual(8.0, sess.run(out)) + self.assertEqual(8.0, self.evaluate(out)) def test_placeholder_with_default_fed(self): with self.cached_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 86536da7fed0e2309beb32fee9c7c605491592ed..5b35c20027700b34500a31e174061d7087094b61 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -91,8 +91,8 @@ class PowerSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of powersign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class PowerSignTest(xla_test.XLATestCase): ) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - 
self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index c41b4171e26af4f7ad0237d7407a5b3691299595..63cc51a470164915b2614a06d18ca1850bb64a3c 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -45,15 +45,17 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. for _ in range(3): update.run() - self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) - self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + self.assertAllClose( + np.array([-2.60260963, -4.29698515]), self.evaluate(var0)) + self.assertAllClose( + np.array([-0.28432083, -0.56694895]), self.evaluate(var1)) opt_vars = opt.variables() self.assertStartsWith(opt_vars[0].name, var0._shared_name) self.assertStartsWith(opt_vars[1].name, var1._shared_name) @@ -74,14 +76,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) - self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + self.assertAllClose(np.array([-1.60261, -2.296985]), self.evaluate(var0)) + self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1)) def testProximalAdagradWithL1(self): with self.cached_session(), self.test_scope(): @@ -98,14 +100,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad for _ in range(10): update.run() - self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) - self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + self.assertAllClose(np.array([-6.663634, -9.190331]), self.evaluate(var0)) + self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1)) def testProximalAdagradWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -122,15 +124,15 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -141,14 +143,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivAdagradwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 3d808e6b8a71ef9fa60b671d07bfd907e9f58efc..5aec433be765dd0a04bd7ab10d5c39a5a7f48c5c 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -42,15 +42,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) - self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + self.assertAllClose(np.array([-0.9, -1.8]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1)) def testProximalGradientDescentwithoutRegularization2(self): with self.cached_session(), self.test_scope(): @@ -64,15 +64,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent for _ in range(3): update.run() - self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) - self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + self.assertAllClose(np.array([0.1, 0.2]), self.evaluate(var0)) + self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1)) def testProximalGradientDescentWithL1(self): with self.cached_session(), self.test_scope(): @@ -86,15 +86,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps proximal gradient descent. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) - self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + self.assertAllClose(np.array([-1.988, -3.988001]), self.evaluate(var0)) + self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1)) def testProximalGradientDescentWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -108,15 +108,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Gradient Descent for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -127,14 +127,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivGradientDescentwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/qr_op_test.py 
b/tensorflow/compiler/tests/qr_op_test.py index 236b1b881dcaffc1a5b0c6395f0605c1d7ef0269..b4d4193e35f9e0e3b23d0242ed076dd811f4ee2b 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -63,7 +63,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. xx = math_ops.matmul(x, x, adjoint_a=True) identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) - precision = self.AdjustedNorm(xx.eval() - identity.eval()) + precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity)) self.assertTrue(np.all(precision < 5.0)) def _test(self, dtype, shape, full_matrices): diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206..97ffad34c00b8ec16eb1ec109ba5d980e0ce673d 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase): # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
@@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) @@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) - y = sess.run(x) + y = self.evaluate(x) def normal_cdf(x): return .5 * math.erfc(-x / math.sqrt(2)) @@ -111,7 +111,7 @@ class RandomOpsTest(xla_test.XLATestCase): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): - return sess.run(special_math.ndtri(x)) + return self.evaluate(special_math.ndtri(x)) a = -2. b = 2. @@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. @@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = np.diag(range(20)).flatten() # Compare sets to avoid randomness behavior changes but make sure still # have all the values. 
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index a6b58020126a3297944f199e99b0801387615564..d23fd125163d1afe8c7fd5e008d4b617ff4b2874 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -3382,10 +3382,10 @@ int main(int argc, char** argv) { } // XLA devices register kernels at construction time; create all known devices // to make sure the kernels are registered. - std::vector devices; + std::vector> devices; TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( tensorflow::SessionOptions(), "", &devices)); - tensorflow::DeviceMgr device_mgr(devices); + tensorflow::DeviceMgr device_mgr(std::move(devices)); tensorflow::Device* ignored; TF_QCHECK_OK( diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 132c59c32c9db0c8759bdbb31f8613c3ef88b485..e8fc81bbb5472669c408b8bbdbcdfcdcf461131f 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -91,6 +91,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] + ONES = [np.ones([34000, 2])] def testReduceSumF32(self, index_dtype): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, @@ -149,6 +150,11 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_REAL_DATA, index_dtype) + def testReduceMeanF16(self, index_dtype): + if np.float16 in self.all_types: + self._testReduction(math_ops.reduce_mean, np.mean, np.float16, self.ONES, + index_dtype) + def testReduceMeanC64(self, index_dtype): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, self.NONEMPTY_COMPLEX_DATA, index_dtype) diff --git a/tensorflow/compiler/tests/rmsprop_test.py 
b/tensorflow/compiler/tests/rmsprop_test.py index 8840a1329a907bddc6ef1cb6dd1c2a6d234def5c..dc3e90b4afa41c08d899ee195d42fb91678bad1c 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -76,7 +76,7 @@ class RmspropTest(xla_test.XLATestCase): rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered) rms_update = rms_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) mg0 = rms_opt.get_slot(var0, "mg") self.assertEqual(mg0 is not None, centered) @@ -92,12 +92,12 @@ class RmspropTest(xla_test.XLATestCase): self.assertTrue(mom1 is not None) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of RMSProp for _ in range(3): - rms_update.run() + self.evaluate(rms_update) var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( var0_np, @@ -118,14 +118,14 @@ class RmspropTest(xla_test.XLATestCase): # Validate updated params if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0)) + self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1)) + self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0)) + self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1)) + 
self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0)) + self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 897db384b7e8067b0460b5f344201f101a4d8479..17639bd8a755b9e9f5acc77979ac7a4149f112db 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse): class CumsumTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) @@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase): class CumprodTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 21708aa15877647e2a979a5a2674dfb734700df3..ee7ca7e6f196e114ff18e2597145e5c198980b08 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -156,7 +156,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): - return sess.run(special_math.ndtri(x)) + return self.evaluate(special_math.ndtri(x)) a = -2. b = 2. 
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 46ca371c8abf1cb4710717a183ee12820c4c4ca0..d7e26d79c4c054860ade5c8960a3bca984e020b0 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -79,7 +79,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() self.assertAllEqual( - convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval()) + convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), + self.evaluate(c0)) def testTensorArrayWritePack(self): for dtype in self.numeric_tf_types: @@ -97,7 +98,7 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() - self.assertAllEqual([3, 0, 1], c0.eval().shape) + self.assertAllEqual([3, 0, 1], self.evaluate(c0).shape) def _testTensorArrayWriteConcat(self, tf_dtype): with self.cached_session(), self.test_scope(): @@ -113,8 +114,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.concat() self.assertAllEqual( - convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], - [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval()) + convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], + [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0)) def testTensorArrayWriteConcat(self): for dtype in self.numeric_tf_types: @@ -341,7 +342,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): - r0_bad.eval() + self.evaluate(r0_bad) # Test reading from a different index than the one we wrote to w0.read(1) @@ -422,7 +423,7 @@ class TensorArrayTest(xla_test.XLATestCase): w2 = h2.write(0, 5.0) r2 = w2.read(0) r = r1 + r2 - self.assertAllClose(9.0, r.eval()) + self.assertAllClose(9.0, self.evaluate(r)) def _testTensorArrayGradientWriteReadType(self, dtype): with self.cached_session() as session, self.test_scope(): @@ -504,7 +505,7 
@@ class TensorArrayTest(xla_test.XLATestCase): [-0.5, 1.5], # read(0) gradient [20.0, 30.0, 40.0, 50.0], # concat gradient ]) - grad_vals = sess.run(grad_r) # 2 + 2 entries + grad_vals = self.evaluate(grad_r) # 2 + 2 entries self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) @@ -526,7 +527,7 @@ class TensorArrayTest(xla_test.XLATestCase): with ops.control_dependencies([r0_readtwice]): r1_readtwice = w_readtwice.read(0) - self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) + self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice)) def _testTensorArrayGradientUnpackRead(self): with self.cached_session() as session, self.test_scope(): @@ -592,7 +593,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() - self.assertAllEqual(3, s.eval()) + self.assertAllEqual(3, self.evaluate(s)) def testWriteCloseTensorArray(self): with self.cached_session(), self.test_scope(): @@ -722,7 +723,7 @@ class TensorArrayTest(xla_test.XLATestCase): # r = acc2.stack() # grad = gradients_impl.gradients(r, [x])[0] - # self.assertAllClose(31.0, grad.eval()) + # self.assertAllClose(31.0, self.evaluate(grad)) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): with self.cached_session() as session, self.test_scope(): @@ -912,7 +913,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertEqual(0, ta.size().eval()) ta = ta.unstack(array_ops.zeros([0, 3, 5])) packed = ta.stack() - self.assertAllEqual([0, 3, 5], packed.eval().shape) + self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape) # Concatenating zero tensors along their first dimension gives a # first dimension of zero self.assertAllEqual([0, 5], ta.concat().eval().shape) @@ -1041,8 +1042,8 @@ class TensorArrayTest(xla_test.XLATestCase): (read0, read1, size0, size1)) # Tests that the control dependencies was added and executed. 
- self.assertEqual(1, v0.eval()) - self.assertEqual(1, v1.eval()) + self.assertEqual(1, self.evaluate(v0)) + self.assertEqual(1, self.evaluate(v1)) # Tests correct TensorArray. self.assertEqual(read0_v, 0) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index d612d3b32dd6b0893508413b337ea9ad95ef6dd7..95c9e7ffd4651642781143c2c1940b0e51e1e470 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -481,6 +481,72 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + def quantize_and_dequantize_v2_round_half_up(x): + return array_ops.quantize_and_dequantize_v2( + x, + -1, + 1.0, + signed_input=True, + num_bits=8, + range_given=True, + round_mode="HALF_UP") + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2_round_half_up, + np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype), + expected=np.array([ + -102.0 / 127, + -63.0 / 127, + 0, + 38.0 / 127, + 102.0 / 127, + -128.0 / 127, + 1, + ], + dtype=dtype)) + + def quantize_and_dequantize_v2_round_half_to_even(x): + return array_ops.quantize_and_dequantize_v2( + x, + -1.0, + 1.0, + signed_input=True, + num_bits=8, + range_given=True, + round_mode="HALF_TO_EVEN") + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2_round_half_to_even, + np.array( + [ + -0.8, + # The -0.5 should become -63.5 after scaling and with + # rounding this should become -64. But with the test + # unary_ops_test_cpu_ondemand, this fails as the result + # before scaling becomes -63.499996 and gets rounded to -63. + # TODO(sreenik): Some one more familiar with this test needs + # to take a look and resolve this. This works on all other + # variations of the platform like cpu, and gpu. 
+ # -0.5, + 0, + 0.3, + 0.8, + -2, + 33 + ], + dtype=dtype), + expected=np.array( + [ + -102.0 / 127, + # -64.0 / 127, + 0, + 38.0 / 127, + 102.0 / 127, + -128.0 / 127, + 1, + ], + dtype=dtype)) + def quantize_and_dequantize_v3(x): return array_ops.quantize_and_dequantize_v3( x, -127, 127, num_bits=8, signed_input=True, range_given=False) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 77cdeac8168aa71555955b141852587d62ab59d3..fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -77,7 +77,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(variables.variables_initializer([v])) x = v.sparse_read(2) self.assertAllClose( - np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x)) + np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x)) def testSparseRead1DIndices(self): for dtype in self.numeric_types: @@ -89,7 +89,7 @@ class VariableOpsTest(xla_test.XLATestCase): x = v.sparse_read([2, 1]) self.assertAllClose( np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype), - sess.run(x)) + self.evaluate(x)) def testSparseRead2DIndices(self): for dtype in self.numeric_types: @@ -102,7 +102,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllClose( np.array([[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype), - sess.run(x)) + self.evaluate(x)) def testSparseRead2DIndices3DTensor(self): for dtype in self.numeric_types: @@ -115,9 +115,9 @@ class VariableOpsTest(xla_test.XLATestCase): x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]] - ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] - ],).astype(dtype), sess.run(x)) + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], + [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] + ],).astype(dtype), self.evaluate(x)) 
def testShape(self): for dtype in self.numeric_types: @@ -229,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[3], [7]]) + self.assertAllEqual(self.evaluate(read), [[3], [7]]) def testScatterSub(self): with self.test_session() as sess, self.test_scope(): @@ -242,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[4], [-1]]) + self.assertAllEqual(self.evaluate(read), [[4], [-1]]) def testScatterMul(self): with self.test_session() as sess, self.test_scope(): @@ -255,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDiv(self): with self.test_session() as sess, self.test_scope(): @@ -268,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[2]]) + self.assertAllEqual(self.evaluate(read), [[2]]) def testScatterMin(self): with self.test_session() as sess, self.test_scope(): @@ -281,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - 
self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMax(self): with self.test_session() as sess, self.test_scope(): @@ -294,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterUpdate(self): with self.test_session() as sess, self.test_scope(): @@ -307,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_update( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterAddScalar(self): with self.test_session() as sess, self.test_scope(): @@ -320,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterSubScalar(self): with self.test_session() as sess, self.test_scope(): @@ -333,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[-1]]) + self.assertEqual(self.evaluate(read), [[-1]]) def testScatterMulScalar(self): with self.test_session() as sess, self.test_scope(): @@ -346,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant(5, 
dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDivScalar(self): with self.test_session() as sess, self.test_scope(): @@ -359,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[2]]) + self.assertEqual(self.evaluate(read), [[2]]) def testScatterMinScalar(self): with self.test_session() as sess, self.test_scope(): @@ -372,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMaxScalar(self): with self.test_session() as sess, self.test_scope(): @@ -385,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterNdAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -400,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) def testScatterNdUpdateAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -416,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase): 
gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) class StridedSliceAssignChecker(object): diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 28d61fb07dcb665fa0dbe3f3e566e291e24fa662..ef55292b1be91a731ec556d7efa9cdf1a696e5cc 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -81,7 +81,7 @@ class XlaDeviceTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() - sess.run(x) + self.evaluate(x) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e0171415492658a76b25167107e01300ee4bde88..5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -9,6 +9,7 @@ package_group( "//tensorflow/compiler/jit/...", "//tensorflow/compiler/tests/...", "//tensorflow/compiler/tf2xla/...", + "//tensorflow/contrib/compiler/...", ], ) @@ -195,8 +196,8 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_cluster_util", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -204,13 +205,13 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", 
"//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -221,6 +222,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -437,21 +439,15 @@ cc_library( name = "dump_graph", srcs = [ "dump_graph.cc", - "dump_graph_flags.cc", - "dump_graph_flags.h", ], hdrs = [ "dump_graph.h", ], deps = [ - "//tensorflow/compiler/xla:parse_flags_from_env", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", + "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 380c6a7e23da92d949b26876836b999bf6406c6c..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,87 +18,26 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace dump_graph { -namespace { - -struct NameCounts { - mutex counts_mutex; - std::unordered_map counts; -}; - -string MakeUniqueFilename(string name) { - static NameCounts& instance = *new NameCounts; - - // Remove illegal characters from `name`. - for (int i = 0; i < name.size(); ++i) { - char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { - name[i] = '_'; - } - } - - int count; - { - mutex_lock lock(instance.counts_mutex); - count = instance.counts[name]++; - } - - string filename = name; - if (count > 0) { - absl::StrAppend(&filename, "_", count); - } - absl::StrAppend(&filename, ".pbtxt"); - return filename; -} - -string WriteTextProtoToUniqueFile( - Env* env, const string& name, const char* proto_type, - const ::tensorflow::protobuf::Message& proto) { - const string& dirname = - legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; - Status status = env->RecursivelyCreateDir(dirname); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dirname << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); - status = WriteTextProto(Env::Default(), filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath - << " : " << status; - return "(unavailable)"; - } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; -} - -} // anonymous namespace - string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", - 
graph_def); + return tensorflow::DumpGraphDefToFile( + name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def) { - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - if (flib_def) { - *graph_def.mutable_library() = flib_def->ToProto(); - } - return DumpGraphDefToFile(name, graph_def); + return tensorflow::DumpGraphToFile(name, graph, flib_def, + GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); + return tensorflow::DumpFunctionDefToFile( + name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc deleted file mode 100644 index 2eb1f8cd849b67922f94cfe3f88456b0d6beeaf8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's dump_graph module. 
- -#include -#include - -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/compiler/xla/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static DumpGraphFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new DumpGraphFlags; - flags->tf_dump_graph_prefix = "/tmp/"; - flag_list = new std::vector({ - Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix, - "Path prefix to which graphs dumped during debugging should be " - "written."), - }); - xla::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -DumpGraphFlags* GetDumpGraphFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h deleted file mode 100644 index 80a3307d920f2cc3d668d507786a02e43589f86f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ - -// Legacy flags for the XLA bridge's dump_graph module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// dump_graph module. -typedef struct { - string tf_dump_graph_prefix; // Path prefix to which graphs dumped during - // debugging should be written. -} DumpGraphFlags; - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-DumpGraphFlags* GetDumpGraphFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 9ef9f49f422ec4dfaf538ac3c0754ba3609d3f88..3dfd3f854c8646ebbf06d3378201d22e8741b7eb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -75,6 +75,25 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library) { + return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr, + graph_def, library); +} + +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library) { + FunctionDefLibrary function_lib = graph_def->library(); + Graph graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library)); + graph.ToGraphDef(graph_def); + std::swap(*graph_def->mutable_library(), function_lib); + return Status::OK(); +} + Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index ba99205640ccdc83a3a4d50e3ec474907894a835..91d33fa405834d7f1f8f66180583580f4f2e448a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -33,6 +33,12 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* 
library); +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library); + // This pass looks at the graph and all associated FunctionDefs, and turns // traditional control flow structure (Switch/Merge/etc.) into functional // control flow structure (If/While). diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c3841f996f801e855da75b23f01d41674ec51c4d..9784985af83a18619d837528f99a60b98a501ec5 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -95,77 +95,87 @@ TEST(FunctionalizeControlFlow, Conditional) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); + 
auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &else_result)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, - then_fn, else_fn); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // then body. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + // then body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - // else body. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -239,75 +249,77 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } 
FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten = ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; 
- TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } -// @function.Defun(noinline=True) -// def increment_fn(x): -// return [x + 1] -// Define the above function, and add it to the given graph. It's used as the -// while loop body in NoinlineLoopBody test. -Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { +FunctionDef GetNoinlineFunctionDef() { FunctionDef fdef = FunctionDefHelper::Create( "increment_fn", {"x:int32"}, {"add:int32"}, {}, { @@ -316,8 +328,17 @@ Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { }, {{"add", "add_0:z:0"}}); (*fdef.mutable_attr())["_noinline"].set_b(true); + return fdef; +} + +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. 
+Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { FunctionDefLibrary fdef_lib; - *(fdef_lib.add_function()) = fdef; + *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); NodeDef increment_fn; increment_fn.set_name(node_name); @@ -376,55 +397,88 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { FunctionLibraryDefinition lookup_lib(graph.flib_def()); FunctionLibraryDefinition library(OpRegistry::Global(), {}); // Function increment_fn will be copied from lookup_lib to library. - TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + *(optimized_graph_def.mutable_library()->add_function()) = + GetNoinlineFunctionDef(); - NameAttrList cond_fn, body_fn; - TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &lookup_lib, &optimized_graph_def, &library)); + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - 
std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK( + AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } +} - // Body graph. 
+TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), source); TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - NodeDef retval; - retval.set_name("_retval0_RetVal"); - retval.set_op(FunctionLibraryDefinition::kRetOp); - *retval.add_input() = noinline_node_name; - (*retval.mutable_attr())["T"].set_type(DT_INT32); - (*retval.mutable_attr())["index"].set_i(0); - Status status; - scope.graph()->AddNode(retval, &status); - TF_ASSERT_OK(status); - - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } - InstantiationResultForTest result; - // Verify that increment_fn has been copied to library. - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + graph_def.clear_library(); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - // Ignore the function library when comparing the graphs. 
- expected.clear_library(); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + Status status = + FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library); + EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); } // Tests functionalizing OneLoopVar where the loop value is not used post the @@ -467,65 +521,72 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten 
= ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, 
result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -608,86 +669,95 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); - auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); - auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto three = ops::Const(scope.WithOpName("while/cond/three") + // Condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const(scope.WithOpName("while/cond/ten") .WithControlDependencies(arg0.output), - 3); - auto cond_add = - ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); - auto ten = ops::Const( - scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - - auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); - auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); - - auto one = ops::Const( - scope.WithOpName("while/add/one").WithControlDependencies(identity_x), - 1); - auto two = ops::Const( - scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), - 2); + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); - auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + + auto identity_x = + ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = + ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -841,177 +911,192 @@ TEST(FunctionalizeControlFlow, Complex) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList outer_cond_fn, outer_body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); - - // Outer graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); - auto y = ops::Add(scope.WithOpName("y"), x, three); - - auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, - TensorShape({})); - - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - - auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Outer condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto ten = ops::Const( - scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Outer body graph. - NameAttrList inner_cond_fn, inner_body_fn; - { - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); - - // Find the inner condition and body names. 
- TF_EXPECT_OK( - FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); - auto one_j = ops::Const( - scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); - auto while_op = - ops::While(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); - - auto one_outer = ops::Const( - scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); - auto add_i = - ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(absl::Span{ - while_op[0].op(), while_op[1].op()}), - identity_i, one_outer); - - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Inner condition graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto five = ops::Const( - scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); - auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList outer_cond_fn, outer_body_fn; TF_EXPECT_OK( - InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Inner body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto identity_j = - ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); - auto identity_k = - ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); - - auto mul_jk = - ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); - auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); - auto assign = ops::AssignAddVariableOp( - scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - - auto one = ops::Const( - scope.WithOpName("outer/inner/One") - .WithControlDependencies( - absl::Span{assign.operation}), - 1); - auto add_j = - ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); - auto retval1 = - ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + // Outer condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - GraphDef expected; - 
TF_EXPECT_OK(scope.ToGraphDef(&expected)); + // Outer body graph. + NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + + // Find the inner condition and body names. + TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), + 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(absl::Span{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Inner body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d85b4f5ae0cb9c7d2476158a5830f921742ae980..8bc329229648c5aced8d06c99b170803bb3a90f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -121,13 +121,11 @@ tf_kernel_library( ":while_op", 
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", - "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", @@ -144,10 +142,11 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -196,7 +195,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -216,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 2db2514397deca39e6874cf994532a20d2186316..795ea09831e183a26fb3498b9bbaf9c3adaef9ed 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -50,7 +50,7 @@ class XlaArgOp : public XlaOpKernel { return; } - const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; + const XlaExpression& arg = ctx->xla_context()->args()[index_]; OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, errors::InvalidArgument("Invalid/missing argument expression")); ctx->SetOutputExpression(0, arg); diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 4cfe946b2e6146f034867c06e996ffae42b90705..1b254e328a8c71bd81a0ec700e2af1d81a5fa67a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" namespace tensorflow { namespace { @@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->Input(0), ctx->Input(1), - /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, - /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); + auto result = + xla::BatchDot(MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), + MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 
a267c0c72fce67d7c22c55a57f8d5ac4ffd2b7e2..0e2f335f3354e3ae6008bdc0ac0b80683fe479c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -115,9 +115,9 @@ class FusedBatchNormGradOp : public XlaOpKernel { // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). auto grad_backprop = - XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + XlaHelpers::ConvertElementType(ctx->Input(0), scale_dtype); auto activations = - XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + XlaHelpers::ConvertElementType(ctx->Input(1), scale_dtype); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); @@ -151,11 +151,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(scale_dtype); auto converted = - XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + XlaHelpers::ConvertElementType(grad_backprop, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); - offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); + offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -165,19 +165,18 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scratch2 = sum(y_backprop * (x - mean)) auto mul = xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); - converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + converted = XlaHelpers::ConvertElementType(mul, accumulation_type); reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); - auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, 
scale_dtype); + auto scratch2 = XlaHelpers::ConvertElementType(reduce, scale_dtype); x_backprop = xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); scale_backprop = xla::Mul(scratch1, scratch2); } - ctx->SetOutput(0, - XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor()); diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 41f540506ba41fbe7f91393e7b8e26a89e72ef0a..e7f369b761f36a717ea5fb536780af91a8955b1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -107,11 +107,11 @@ class BiasAddGradOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = - XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 47e517a6576d3a848bc41ceb703df2bd778c4a35..5e9280c1fe692037b0a842a92ef5a8c28b854a54 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -43,6 +43,9 @@ namespace { const std::vector& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ (void)b; \ + (void)lhs_shape; \ + (void)rhs_shape; \ + (void)extend_dimensions; \ return HLO; \ } \ }; \ @@ -103,23 +106,23 @@ static xla::XlaOp 
FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, - xla::XlaOp y, const BCast& broadcast_helper) { +xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); - auto zero = XlaHelpers::Zero(b, dtype); + auto zero = xla::ZerosLike(x); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); } -XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper)); -static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, - xla::XlaOp y, const BCast& broadcast_helper) { +xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); - auto zero = XlaHelpers::Zero(b, dtype); + auto zero = xla::ZerosLike(x); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Div(x, y)); } -XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper)); // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index ad85940920ebb82e72331516e3fe46c79f853892..7199b9b6feb36dd45ef51f4c38463bc715fcc38a 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,10 +21,13 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -57,11 +60,9 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::XlaBuilder* builder = ctx->builder(); - xla::Shape uniform_shape; int class_dimension; - if (num_samples > 1) { + if (num_samples != 1) { std::array uniform_shape_array = { {batch_size, num_samples, num_classes}}; xla::PrimitiveType uniform_xla_type; @@ -83,16 +84,16 @@ class CategoricalOp : public XlaOpKernel { xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); class_dimension = 1; } - xla::XlaOp uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type)); + xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. 
auto softmax_entries = - xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + xla::Sub(logits, log_uniforms, /*broadcast_dimensions=*/{0, class_dimension}); xla::PrimitiveType xla_output_type; @@ -107,6 +108,16 @@ class CategoricalOp : public XlaOpKernel { ctx->SetOutput(0, argmax); } + virtual xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, + xla::PrimitiveType type, + XlaOpKernelContext* ctx) { + xla::XlaBuilder* builder = ctx->builder(); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); + return xla::Log(-xla::Log(uniforms)); + } + private: TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); }; @@ -115,5 +126,48 @@ class CategoricalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"), CategoricalOp); +class StatelessCategoricalOp : public CategoricalOp { + public: + explicit StatelessCategoricalOp(OpKernelConstruction* ctx) + : CategoricalOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, + XlaOpKernelContext* ctx) override { + xla::XlaOp seed = ctx->Input(2); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::XlaBuilder* builder = ctx->builder(); + if (uniform_shape.element_type() == xla::BF16) { + uniform_shape.set_element_type(xla::F32); + } + auto uniforms = xla::StatelessRngUniform( + {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), + XlaHelpers::One(builder, DT_FLOAT)); + return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape seed_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + 
CategoricalOp::Compile(ctx); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp); +}; + +REGISTER_XLA_OP(Name("StatelessMultinomial") + .CompileTimeConstantInput("num_samples") + .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("Tseed", DT_INT32), + StatelessCategoricalOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index c9a1be494066e4f935a1d818bc86c86333e34fae..641fefafb357f6ad10483c454600f3dadd4f8cb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" @@ -65,60 +64,63 @@ xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 // -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 +// The first step is to create a iota A with iota_dimension = 2 +// 0 0 0 0 0 0 0 0 0 0 0 0 +// 1 1 1 1 1 1 1 1 1 1 1 1 +// 2 2 2 2 2 2 2 2 2 2 2 2 // -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 +// 0 0 0 0 0 0 0 0 0 0 0 0 +// 1 1 1 1 1 1 1 1 1 1 1 1 +// 2 2 2 2 2 2 2 2 2 2 2 2 // -// and divide B it by 2 to get -// 0 0 1 1 2 2 +// and another iota B with iota_dimension = 3 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 // -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 
2 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 // -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 +// and divide B by 2 to get +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 // -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and B and return the result at the beginning of the +// comment. xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, xla::XlaBuilder* builder) { xla::Shape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); int64 depthwise_multiplier = filter_shape.dimensions(filter_shape.dimensions_size() - 1); - int64 input_feature = - filter_shape.dimensions(filter_shape.dimensions_size() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. + // Create two iotas with the shape of the expanded filter, one of them with + // the iota dimension chosen as the feature dimension, and the other a iota + // with the iota dimension chosen as the expanded output feature dimension. 
+ std::vector iota_dimensions(expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions); + xla::XlaOp input_feature_iota = xla::Iota( + builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2); + xla::XlaOp expanded_feature_iota = xla::Iota( + builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1); + + // Divide 'expanded_feature_iota' by the depthwise_multiplier to create + // [0 0 1 1 2 2] ... in the example in the function comment. expanded_feature_iota = xla::Div(expanded_feature_iota, XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, depthwise_multiplier)); - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - std::vector expanded_feature_broadcast_dims( - expanded_filter_shape.dimensions().begin(), - expanded_filter_shape.dimensions().end()); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dimensions_size() - 2}); + // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a + // diagonal predicate. + return xla::Eq(expanded_feature_iota, input_feature_iota); } // Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index d820528a43064e327cb90e5a2889f77ab1f3f3e2..eafdba876ae9e2c38694f065cf83bb3725b8460e 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 49c12fc232092873b69961644a059abc6035f64f..ee79cbc70da269be7586c47b4fd33c901f4fd581 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index b2f6ef43fa9765b0d6da8e3215cbea5b56b4fe05..6e6ba21daf5bf3eab5bfc15378e77b6dd253da7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -113,8 +113,20 @@ class DynamicStitchOp : public XlaOpKernel { } } int number_of_indices = max_index + 1; - OP_REQUIRES(ctx, number_of_indices > 0, - errors::InvalidArgument("no indices supplied")); + int64 result_rank = 1 + data0_shape.dims() - indices0_shape.dims(); + if (number_of_indices == 0) { + std::vector result_shape(result_rank); + for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { + 
result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d); + } + xla::PrimitiveType element_type = + ctx->input_xla_type(ctx->num_inputs() - 1); + xla::Literal empty_literal = xla::Literal::CreateFromShape( + xla::ShapeUtil::MakeShape(element_type, result_shape)); + ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal)); + return; + } + // Construct the reverse mapping, for each index, of which slice of which // input it comes from. std::vector src_input_vector(number_of_indices); @@ -157,12 +169,9 @@ class DynamicStitchOp : public XlaOpKernel { // Set up the vectors for slicing: the first dimension will vary // slice by slice, and the rest take the full common extra shape. - std::vector slice_start(1 + data0_shape.dims() - - indices0_shape.dims()); - std::vector slice_limit(1 + data0_shape.dims() - - indices0_shape.dims()); - std::vector stride(1 + data0_shape.dims() - indices0_shape.dims(), - 1); + std::vector slice_start(result_rank); + std::vector slice_limit(result_rank); + std::vector stride(result_rank, 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index c68b0bfd7961892294c2931e5c4c44de534a7740..29687c7b82f92d9f336854c4575746589c63b64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index cdba6680dee3fade5bdf0c453ed672b653072b0d..142be030f737f105980ab9c80a5a849e1ca6eb47 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -260,19 +260,19 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { xla::XlaOp below_min = xla::Lt(input, nudged_input_min); xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); xla::XlaOp reduce1 = xla::ReduceAll( - XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::ConvertElementType(select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); + xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type); ctx->SetOutput(1, output1); xla::XlaOp above_max = xla::Gt(input, nudged_input_max); xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); xla::XlaOp reduce2 = xla::ReduceAll( - XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::ConvertElementType(select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type); + xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 
9b06357d9b78be6d7b64e88a97f45f6c19176fc8..6df8b5367d2390e65995beb1583b225755e6ee9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -50,11 +51,36 @@ class GenericFftOp : public XlaOpKernel { errors::InvalidArgument("input must be at least 1 dimensional")); std::vector fft_length; + xla::XlaOp input = ctx->Input(0); if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length)); OP_REQUIRES(ctx, fft_length.size() == fft_rank_, errors::InvalidArgument("fft_length must be length ", fft_rank_, " vector")); + + // Zero pad or truncate the axes we're doing FFT on. + absl::InlinedVector slice_sizes = input_shape.dim_sizes(); + std::vector> padding_sizes(slice_sizes.size()); + std::vector expected_sizes = fft_length; + // IRFFT wants the innermost axis to be n / 2 + 1. 
+ if (fft_type_ == FftType::IRFFT) { + expected_sizes[fft_rank_ - 1] = fft_length[fft_rank_ - 1] / 2 + 1; + } + for (int i = 0; i < fft_rank_; i++) { + int index = input_shape.dims() - fft_rank_ + i; + if (input_shape.dim_size(index) > expected_sizes[i]) { + slice_sizes[index] = expected_sizes[i]; + } else { + padding_sizes[index].second = + expected_sizes[i] - input_shape.dim_size(index); + } + } + + std::vector start_indices(input_shape.dims(), 0); + std::vector strides(input_shape.dims(), 1); + input = xla::Pad(xla::Slice(input, start_indices, slice_sizes, strides), + XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), + xla::MakeEdgePaddingConfig(padding_sizes)); } else { // Innermost axis provides the FFT length. for (int i = 0; i < fft_rank_; i++) { @@ -63,7 +89,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(input, fft_type_, fft_length); ctx->SetOutput(0, fft); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 56da50f140893c68c8a1556853884720b21c7229..b5e083912555c865b5eadc7697075c9ca4451ca9 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -72,7 +72,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.shape = resource->shape(); OP_REQUIRES(ctx, arg.initialized, errors::Unimplemented("Uninitialized arguments: ", arg.name)); - arg.tensor_array_size = resource->tensor_array_size(); + arg.max_array_size = resource->max_array_size(); for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index b49b2516d8b829a550071bc7580d350328833f32..e9bb0a77e99d144863b027bd214081316d61c314 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -191,12 +191,11 @@ class AdjustContrastOpV2 : public XlaOpKernel { DataType type = context->input_type(0); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); - auto converted = - XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); - auto output = XlaHelpers::ConvertElementType(b, reduce, type); + auto output = XlaHelpers::ConvertElementType(reduce, type); output = xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 0c7ca602bfacd598dada0303d3a3e77fe7f1b0fc..5a10c52ba8b6d4fab73f0dda67cbd52fd625e76b 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index e310db2162da0997204f85bc3ca42e7b0460e1e3..e2c05b648bb194b1b452c527ddb1a2c5995b1217 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -30,7 +30,9 @@ limitations under the License. namespace tensorflow { namespace { -// The logic below uses a custom-call to implement argmax. 
+// The logic below uses a custom-call to implement argmax when possible. When +// custom-call is not allowed or input shapes are not supported, this kernel +// falls back to using XLA HLO native ArgMax. // // Also see b/29507024 for first-class XLA support for indexing ops. class ArgMaxCustomCallOp : public XlaOpKernel { @@ -50,27 +52,40 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // overhead, when compiling ahead-of-time. int64 dim; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); - OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); - OP_REQUIRES( - ctx, dim < input_shape.dims(), - errors::InvalidArgument("dim must be < input rank (", - input_shape.dims(), "), but got: ", dim)); - const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES(ctx, dim_size > 0, + + const int input_dims = input_shape.dims(); + const int axis = dim < 0 ? dim + input_dims : dim; + OP_REQUIRES(ctx, axis >= 0 && axis < input_dims, + errors::InvalidArgument("Expected dimension in the range [", + -input_dims, ", ", input_dims, + "), but got ", dim)); + + const int64 axis_size = input_shape.dim_size(axis); + OP_REQUIRES(ctx, axis_size > 0, errors::InvalidArgument( "Reduction axis ", dim, " is empty in shape: ", input_shape.DebugString())); - // The output shape is the input shape contracted along dim. + const DataType dtype = output_type(0); + xla::PrimitiveType output_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type)); + + // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input + // shape isn't supported. + if (!ctx->compiler()->options().allow_cpu_custom_calls || + (input_dims != 1 && input_dims != 2)) { + xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + ctx->SetOutput(0, output); + return; + } + + xla::XlaOp output; + // The output shape is the input shape contracted along axis. 
TensorShape output_shape; for (int d = 0; d < input_shape.dims() - 1; ++d) { - output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1)); } - // For now we use a custom-call, only for the 1d and 2d cases. - OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), - errors::InvalidArgument( - "ArgMax implementation requires a CustomCall on CPU")); xla::XlaBuilder& b = *ctx->builder(); // XLA passes to the function, so it is not included here. @@ -84,7 +99,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { args.push_back(xla::ConstantLiteral( &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(axis))); } // The argmax function expects row-major layout. @@ -101,24 +116,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { } // Tell XLA to call the custom code, defined in - // index_ops_kernel_argmax_float_1d.cc. - xla::XlaOp output; - switch (input_shape.dims()) { - case 1: - output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, - xla_shape, arg_shapes); - break; - case 2: - output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, - xla_shape, arg_shapes); - break; - default: - OP_REQUIRES(ctx, false, - errors::Unimplemented( - "Argmax is only implemented for 1d and 2d tensors" - ", but got shape: ", - input_shape.DebugString())); + // index_ops_kernel_argmax_float_{1, 2}d.cc. 
+ if (input_dims == 1) { + output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, + xla_shape, arg_shapes); + } else { + output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, + xla_shape, arg_shapes); } + output = xla::ConvertElementType(output, output_type); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index f028e361bccd51de0bd69a1d2227c7afaed53455..93f029731c34e84000a3dc00df8af05654cccf2d 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -37,12 +37,11 @@ class L2LossOp : public XlaOpKernel { // output = sum(t ** 2) / 2 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto t = - XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto t = XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); auto square = xla::Mul(t, t); auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), dims); - auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto deconverted = XlaHelpers::ConvertElementType(reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); ctx->SetOutput(0, xla::Div(deconverted, two)); } diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 87ee2d3aede50eb24e65570f106d49030e1d4236..987901d82b3f3798dd52f18ef2497b8f0cf80b11 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -49,16 +49,14 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. 
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); - auto converted = - XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto squared = xla::Mul(converted, converted); auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto sqr_sum = - XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); + auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0)); auto scale = xla::Pow( xla::Add(xla::ConstantR0(builder, bias_), @@ -138,15 +136,14 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = - XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + XlaHelpers::ConvertElementType(in_image, accumulation_type); auto squared = xla::Mul(converted, converted); auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto sqr_sum = - XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); + auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0)); auto norm = xla::Add(xla::ConstantR0(builder, bias_), @@ -157,15 +154,13 @@ class LRNGradOp : public XlaOpKernel { xla::Div(out_image, norm)), in_grads); - auto converted_dy = - XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type); auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 
+ 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto dy_reduced = - XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); + auto dy_reduced = XlaHelpers::ConvertElementType(dy_reduce, input_type(0)); xla::XlaOp gradients = xla::Add( xla::Mul(in_image, dy_reduced), diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 8dfd7de591c4a3c4768dd60b41e03d294ad49397..2dd0a710e47ec8cad6153402fdb3be59f5868205 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -61,11 +61,11 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); - xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); + xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n}); + xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0); + xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1); - auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(iota_n, iota_m); // If num_lower or num_upper are negative, include all lower/upper // diagonals. 
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6..4f980b6d14ed667bdf4756ed740894098cae5919 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index f4def11d08c31513aec5aad15187016a7294c2fd..90c0ebefb24ec2c4378782e9b15d3f57c33032a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" namespace tensorflow { namespace { @@ -29,7 +29,7 @@ class MatrixTriangularSolveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = TriangularSolve( + auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc index 94b51e1a586c6cf623c181abf200b91851c7ba05..71920bf5c1e6aa5981aafa8b611cc01c0917e02b 100644 --- a/tensorflow/compiler/tf2xla/kernels/permute_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -75,8 +75,7 @@ class DataFormatVecPermuteOp : public XlaOpKernel { } auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); if (input_rank == 2) { - keys = xla::BroadcastInDim( - keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + keys = xla::BroadcastInDim(keys, {4, 2}, {0}); } auto sorted = xla::Sort(keys, {ctx->Input(0)}, 0); auto output = xla::GetTupleElement(sorted, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a259da6383d461fd11b0d79096bf66aae7ddef06..06c6cc37ec90192486ba15010bfeb763a9ffb987 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -152,7 +152,12 @@ class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, - /*reduction_type=*/ctx->input_type(0)) {} + /*reduction_type=*/ctx->input_type(0)) 
{ + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -180,10 +185,6 @@ class MaxPool2DOp : public MaxPoolOp { public: explicit MaxPool2DOp(OpKernelConstruction* ctx) : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); @@ -204,7 +205,12 @@ class AvgPoolOp : public PoolingOp { AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ - XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + XlaHelpers::SumAccumulationType(ctx->input_type(0))) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -241,10 +247,6 @@ class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -390,6 +392,11 @@ class AvgPoolGradOp : public XlaOpKernel { OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); + + string 
data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -449,10 +456,6 @@ class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP( diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 6f4ed496a1774dde68dd9d5fbd37995d615b678c..7fe102428db1cc5ce16037f56fa301d1941da8e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" @@ -26,12 +27,26 @@ limitations under the License. namespace tensorflow { namespace { +enum QuantizerRoundMode { + // Round half up: if the fraction of y is exactly 0.5, then + // round(y) = y + 0.5 + // E.g., -5.5 gets rounded to -5, -5.4 goes to -5, + // 5.4 goes to 5, and 5.5 goes to 6. + ROUND_HALF_UP, + // Round half to even: if the fraction of y is exactly 0.5, then round(y) is + // the nearest even integer to y. 
+ // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes + // -24, and -24.5 gets rounded to -24. + ROUND_HALF_TO_EVEN, +}; + class QuantizeAndDequantizeOp : public XlaOpKernel { public: explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); + round_mode_ = ROUND_HALF_TO_EVEN; } void Compile(XlaOpKernelContext* ctx) override { @@ -117,8 +132,17 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // in that case they were measured from the tensor. input = Clamp(min_range, input, max_range); } - xla::XlaOp result = - Floor((input - min_range) * scale + half) * inverse_scale + min_range; + xla::XlaOp result; + switch (round_mode_) { + case ROUND_HALF_TO_EVEN: { + result = xla::RoundToEven(input * scale) * inverse_scale; + break; + } + case ROUND_HALF_UP: { + result = Floor(input * scale + half) * inverse_scale; + break; + } + } ctx->SetOutput(0, result); } @@ -126,6 +150,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { int64 num_bits_ = -1; bool signed_input_; bool range_given_; + QuantizerRoundMode round_mode_; }; class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { @@ -136,6 +161,20 @@ class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 
62 : 63), errors::InvalidArgument("num_bits is out of range: ", num_bits_, " with signed_input_ ", signed_input_)); + string round_mode_string; + OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); + OP_REQUIRES( + ctx, + (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"), + errors::InvalidArgument("Round mode string must be " + "'HALF_UP' or " + "'HALF_TO_EVEN', is '" + + round_mode_string + "'")); + if (round_mode_string == "HALF_UP") { + round_mode_ = ROUND_HALF_UP; + } else if (round_mode_string == "HALF_TO_EVEN") { + round_mode_ = ROUND_HALF_TO_EVEN; + } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 415ce9b77ffeac8a6a5f3c23537afb16c1d3567c..8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 107fa62967a55dffcfff8728b65338564e5202d2..65e158d64fdd7df62d50b81c9e488b2d03476fb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -113,12 +113,21 @@ class MeanOp : public XlaReductionOp { xla::Add(scalar_lhs, scalar_rhs); } - xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) override { - auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), - 
num_elements_reduced); - return reduce_output / divisor; + xla::XlaOp BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce) override { + if (dimensions_to_reduce.empty()) { + return reduce_output; + } + auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + for (int i = 1; i < dimensions_to_reduce.size(); i++) { + auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + divisor = xla::Mul(divisor, size); + } + divisor = xla::ConvertElementType(divisor, xla_reduction_type_); + return XlaHelpers::ConvertElementType(reduce_output / divisor, + input_type(0)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 466e79828d111ee7cadcf713703e8f252c63e62c..af716eab79886791e8507a84984b7ca60865d00e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -48,13 +48,14 @@ class XlaReductionOp : public XlaOpKernel { const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired - // computation should be added to 'builder'. Argument 'reduce_output' is the - // output of the reduction. 'num_elements_reduced' is the number of elements - // that contributed to the reduction. Returns the transformed reduction - // output, Defaults to returning 'reduce_output' unchanged. - virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced); + // computation should be added to 'builder'. Argument 'input' is the original + // input of the reduction; 'reduce_output' is the output of the reduction. + // Returns the transformed reduction output. Defaults to returning + // 'reduce_output' converted to the input type. 
+ virtual xla::XlaOp BuildFinalizer( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 118f2798d559f43acb7f6394a7337426164325ef..2ca2a85244b4edfe75db3d4fff6c2058adc2bf71 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -35,12 +35,13 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } -// Unless BuildFinalizer is overridden the reduction has no -// finalizer. -xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) { - return reduce_output; +// The default finalizer converts the results back into the input type. This can +// be overridden. 
+xla::XlaOp XlaReductionOp::BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& /*input*/, + const xla::XlaOp& reduce_output, + const std::vector& /*dimensions_to_reduce*/) { + return XlaHelpers::ConvertElementType(reduce_output, input_type(0)); } void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { @@ -71,7 +72,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; - int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { int64 index = axes[i]; OP_REQUIRES(ctx, @@ -82,7 +82,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { index = (index + data_shape.dims()) % data_shape.dims(); bitmap[index] = true; xla_axes.push_back(index); - num_elements_reduced *= data_shape.dim_size(index); } std::vector final_shape; @@ -118,8 +117,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); - auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto finalized = BuildFinalizer(b, data, reduce, xla_axes); auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 847704608fb32b43ffb61f87556d5231b9e39cde..54d34a38abc4948a1a08197d72e3e7f763649093 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,9 +43,6 @@ namespace { using xla::XlaOp; -// TODO(b/112295522): note that sampling from image boundary is not currently -// being handled properly. - // Calculates the bilinear weight tensor, given basis ratio (px, py) of the // sampling position: // W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] @@ -70,11 +66,8 @@ XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, std::vector last_two_dims_indices = {(broadcast_dims_size - 2), (broadcast_dims_size - 1)}; - xla::Shape broadcast_shape = - xla::ShapeUtil::MakeShape(xla_type, broadcast_dims); - auto broadcast_first_term = - xla::BroadcastInDim(first_term, broadcast_shape, last_two_dims_indices); + xla::BroadcastInDim(first_term, broadcast_dims, last_two_dims_indices); // Ratio is of the same dimension as warp, which is [batch, dim_0,... 
dim_n, // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the @@ -85,7 +78,7 @@ XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2); auto broadcast_ratio = - xla::BroadcastInDim(ratio, broadcast_shape, ratio_broadcast_indices); + xla::BroadcastInDim(ratio, broadcast_dims, ratio_broadcast_indices); auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio; @@ -96,7 +89,7 @@ XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, sign_change = xla::ConvertElementType(sign_change, xla_type); auto broadcast_sign_change = - xla::BroadcastInDim(sign_change, broadcast_shape, last_two_dims_indices); + xla::BroadcastInDim(sign_change, broadcast_dims, last_two_dims_indices); auto flipped = first_term_subtract_weights * broadcast_sign_change; @@ -232,21 +225,19 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector weights_with_channels_dims = reshaped_weights_dims; weights_with_channels_dims.push_back(data_channels); - auto weights_with_channels_shape = - xla::ShapeUtil::MakeShape(warp_type, weights_with_channels_dims); std::vector reshaped_weights_indices(reshaped_weights_dims.size()); std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), 0); // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. 
auto broadcast_reshaped_weights = xla::BroadcastInDim( - reshaped_weights, weights_with_channels_shape, reshaped_weights_indices); + reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); std::vector grad_output_indices(warp_dims_without_last_dims.size()); std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0); grad_output_indices.push_back(weights_with_channels_dims.size() - 1); XlaOp broadcast_grad_output = xla::BroadcastInDim( - grad_output, weights_with_channels_shape, grad_output_indices); + grad_output, weights_with_channels_dims, grad_output_indices); auto grad_output_multiply_weights = broadcast_grad_output * broadcast_reshaped_weights; @@ -294,13 +285,10 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, std::vector warp_dims_without_last_dims(warp_dims.begin(), warp_dims.end() - 1); + // With dimension [batch, dim_0, ...dim_n, 4] std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; neighbor_broadcast_dims.push_back(4); - // With dimension [batch, dim_0, ...dim_n, 4] - auto neighbor_broadcast_shape = - xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); - // The dimension is [batch, dim_0, ... 
dim_n, 4, data_channels] auto neighbors_data = Gather2by2Neighbors( ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); @@ -326,7 +314,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {0, 0, -1, 1}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // img_cxfy - img_fxfy @@ -334,7 +322,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {-1, 1, 0, 0}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // img_cxcy - img_cxfy @@ -342,7 +330,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {0, -1, 0, 1}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // img_fxcy - img_fxfy @@ -350,7 +338,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::BroadcastInDim( xla::ConvertElementType( xla::ConstantR1(ctx->builder(), {-1, 0, 1, 0}), data_type), - neighbor_broadcast_shape, {last_warp_dim}), + neighbor_broadcast_dims, {last_warp_dim}), neighbors_data, dot_dims, /*precision_config=*/nullptr); // Slice out x and y. 
@@ -421,12 +409,13 @@ class ResamplerOp : public XlaOpKernel { OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, errors::InvalidArgument( "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); XlaOp data = ctx->Input("data"); XlaOp warp = ctx->Input("warp"); // Find the coordinates of the top left corner for the 2x2 region to be - // sampled from. The dimensions are (batch, dim_0, ... dim_n, 2) where the + // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the // last dimension of size 2 in turn is [x, y]. XlaOp top_left = xla::ConvertElementType(warp, xla::U32); @@ -457,10 +446,54 @@ class ResamplerOp : public XlaOpKernel { dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + // The dimension is [batch, dim_0, ...dim_n, data_channels]. auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims, /*precision_config=*/nullptr); - ctx->SetOutput(0, blended_pixels); + // Handle out of boundary cases by constructing a predicate mask array based + // on the in-bound condition, and output 0 for the blended pixel value if + // out-bound. The dimension is the same as top_left: [batch, dim_0, + // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate. + + auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp)); + + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dim_size(2) - 1), + /*height=*/static_cast(data_shape.dim_size(1) - 1)}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. 
+ auto is_in_bound = xla::Reduce( + is_in_bound_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, + ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which + // is the dimension of the result: + // [batch, dim_0, ...dim_n, data_channels]. + auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(data_channels); + + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims); + auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros); + + ctx->SetOutput(0, result); } }; @@ -473,6 +506,8 @@ class ResamplerGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); } + // TODO(b/112295522): note that sampling from image boundary is not currently + // being handled properly. void Compile(XlaOpKernelContext* ctx) override { TensorShape data_shape_tf = ctx->InputShape("data"); OP_REQUIRES(ctx, data_shape_tf.dims() == 4, diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 6970dd0a00641c9f88571561501fb3454fb3eab3..e4046c795577983bff1a8053743bf4d3a258e583 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -47,8 +47,7 @@ class RetvalOp : public XlaOpKernel { // compilation. 
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - XlaContext& xla_context = XlaContext::Get(ctx); - xla_context.SetRetval(index_, ctx->InputExpression(0)); + ctx->xla_context()->SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 7ff3e9163811434e8d621795c22bf8304ba7a1ed..d7b38e86cc985d608116488f9e76756a8e904f9c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index b5fd7850bfca01868273c40cbf86188bd815be5b..4b9e1a578be2445091228953df7e5c5e82b42c28 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -39,8 +39,8 @@ namespace { // TODO(phawkins): implement double-sized windowed reductions in XLA and remove // the type constraint. 
-constexpr std::array kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; class ScanOp : public XlaOpKernel { public: @@ -103,11 +103,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = xla::ReduceWindowWithGeneralPadding( - XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, + XlaHelpers::ConvertElementType(ctx->Input(0), dtype), init, *reducer, + window_dims, window_strides, /*base_dilations=*/{}, /*window_dilations=*/{}, padding); - output = - XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); + output = XlaHelpers::ConvertElementType(output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index a7f5a8f1698b9d02560de427d356e9e6be5caa7c..84470b230d421658e0d79dcecb175a24155f49b7 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -42,7 +42,7 @@ SendOp::SendOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } void SendOp::Compile(XlaOpKernelContext* ctx) { - XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); + XlaCompiler* compiler = ctx->compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); xla::Send(ctx->Input(0), channel); @@ -73,7 +73,7 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } void RecvOp::Compile(XlaOpKernelContext* ctx) { - XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); + XlaCompiler* compiler = ctx->compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); ctx->SetOutput(0, xla::Recv(ctx->builder(), 
shape_, channel)); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 60b011ba6d9b64a89e4228ba2a213f72b67a462d..b1fa2915d59e4e5e2f2523e20e9a37898d087117 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index d6bd927135c013ac1ec3f6547aef358dc2741896..20da8033536e3af3da0fcb216db45f808cacc1d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -71,7 +71,7 @@ class SoftmaxOp : public XlaOpKernel { auto reduce = xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto sum = XlaHelpers::ConvertElementType(reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) @@ -111,11 +111,11 @@ std::pair CrossEntropyWithLogits( // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto 
sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); + auto sum_exp = XlaHelpers::ConvertElementType(reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = xla::Log(sum_exp); @@ -126,11 +126,10 @@ std::pair CrossEntropyWithLogits( // (The subtraction broadcasts along the batch dimension.) auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); auto mul = xla::Mul(xla::Neg(labels), sub); - auto sum = - xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto loss = XlaHelpers::ConvertElementType(b, sum, type); + auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 7b96b43ad834c28aa0283c5ef4ac516618ca5134..8e9e4daf99d3dd3b8e149e3f3e5f6c27665c0fcb 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -69,7 +69,7 @@ Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, } TensorShape stack_shape; - stack_shape.AddDim(resource->tensor_array_size()); + stack_shape.AddDim(resource->max_array_size()); stack_shape.AppendShape(elem_shape); if (!resource->initialized()) { @@ -97,10 +97,10 @@ class StackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - int64 size; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); + int64 max_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &max_size)); OP_REQUIRES( - ctx, size >= 0, + ctx, max_size >= 0, errors::InvalidArgument( "XLA compilation requires a fixed 
stack size upper bound. If " "you are using tf.while_loop, set the maximum_iterations parameter " @@ -108,14 +108,9 @@ class StackOp : public XlaOpKernel { // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. - xla::XlaOp value; - XlaContext& xc = XlaContext::Get(ctx); - XlaResource* resource; - string name = absl::StrCat("Stack: ", stack_name_); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - TensorShape(), value, /*tensor_array_size=*/size, - /*tensor_array_gradients=*/{}, &resource)); + XlaResource* resource = + ctx->xla_context()->AddResource(XlaResource::CreateStack( + /*name=*/absl::StrCat("Stack: ", stack_name_), dtype_, max_size)); ctx->SetResourceOutput(0, resource); } diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5db52781be473a9a1aef0adf105e3edf69ccd306..50653d7b3973b73d580cdeec5d71943b575d7cc9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 252967a74649f5089f0cb0a9166b1d2b6e094f27..939d7e19515a1cb41e3e23e9d1fa957ae09ecab7 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -61,8 +61,8 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(resource->tensor_array_size() >= 0) - << resource->name() << " size " << resource->tensor_array_size(); + TF_RET_CHECK(resource->max_array_size() >= 0) + << resource->name() << " size " << resource->max_array_size(); if (!resource->initialized()) { TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); @@ -78,7 +78,7 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AddDim(resource->max_array_size()); ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( @@ -114,7 +114,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); - shape->InsertDim(0, resource->tensor_array_size()); + shape->InsertDim(0, resource->max_array_size()); return Status::OK(); } @@ -166,13 +166,10 @@ class TensorArrayOp : public XlaOpKernel { value = xla::Broadcast(zero, ta_shape.dim_sizes()); } - 
XlaContext& xc = XlaContext::Get(ctx); - XlaResource* var; - string name = absl::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, shape, value, /*tensor_array_size=*/size, - /*tensor_array_gradients=*/{}, &var)); + XlaResource* var = + ctx->xla_context()->AddResource(XlaResource::CreateTensorArray( + /*name=*/absl::StrCat("TensorArray: ", tensor_array_name_), dtype_, + shape, /*initial_value=*/value, /*max_array_size=*/size)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -517,14 +514,13 @@ class TensorArraySplitOp : public XlaOpKernel { xla::XlaOp ta = resource->value(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AddDim(resource->max_array_size()); ta_shape.AppendShape(elem_shape); - OP_REQUIRES( - ctx, lengths.size() == resource->tensor_array_size(), - errors::InvalidArgument( - "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", resource->tensor_array_size(), ")")); + OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(), + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. 
", resource->max_array_size(), ")")); const xla::XlaOp value = ctx->Input(1); const xla::XlaOp flow = ctx->Input(3); @@ -562,8 +558,7 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar()() = - static_cast(var->tensor_array_size()); + size_tensor.scalar()() = static_cast(var->max_array_size()); ctx->SetConstantOutput(0, size_tensor); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 8a0c94cfae1b298bd62a3231caf39ecf9b32880e..ee3bdf3394e37c757f31724e73e95417becaa534 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 7077c2e3a546e198bdb4ff944ea531f3158810f2..960c1462ceb8c00a2d6c96564f6c985fd1caef0f 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -320,9 +320,8 @@ class ResourceApplyAdagradDA : public XlaOpKernel { xla::XlaOp lr = ctx->Input(4); xla::XlaOp l1 = ctx->Input(5); xla::XlaOp l2 = ctx->Input(6); - xla::XlaBuilder* const b = ctx->builder(); xla::XlaOp global_step = - XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_); + XlaHelpers::ConvertElementType(ctx->Input(7), dtype_); accum = accum + grad; squared_accum = squared_accum + xla::Square(grad); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc 
b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 559414eeaa5fec75e5a9d1866baaf738c024cd15..ce007fc04a818869686b9936a1607cee42665e87 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -64,7 +64,7 @@ Status MakeXlaCompilerArgumentsFromInputs( if (!arg.initialized) { *has_uninitialized_vars = true; } - arg.tensor_array_size = resource->tensor_array_size(); + arg.max_array_size = resource->max_array_size(); for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index a9f88a6df2539b06ff44fb0aa49c2f2ae1389100..ad8e707e1116d01d492575986a7ab9586022f6b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -89,13 +89,10 @@ class XlaBroadcastHelperOp : public XlaOpKernel { lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); broadcast_shape[dim] = min_rank_shape->dim_size(i); } - xla::PrimitiveType type = context->input_xla_type(0); - xla::Shape broadcast_xla_shape = - xla::ShapeUtil::MakeShape(type, broadcast_shape); if (broadcast_lhs) { - lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + lhs = xla::BroadcastInDim(lhs, broadcast_shape, broadcast_dims); } else { - rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + rhs = xla::BroadcastInDim(rhs, broadcast_shape, broadcast_dims); } context->SetOutput(0, lhs); context->SetOutput(1, rhs); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 1ce3930fd1cd91f8e8dfb765b49be2dc969d1bd7..3e7a761120317ff85947559b7b2e52be9232afb7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -17,20 +17,6 @@ filegroup( 
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") -cc_library( - name = "batch_dot", - srcs = ["batch_dot.cc"], - hdrs = ["batch_dot.h"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", - ], -) - cc_library( name = "broadcast", srcs = ["broadcast.cc"], @@ -52,8 +38,6 @@ cc_library( srcs = ["cholesky.cc"], hdrs = ["cholesky.h"], deps = [ - ":batch_dot", - ":triangular_solve", ":util", ":while_loop", "//tensorflow/compiler/xla:literal", @@ -63,6 +47,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:lib", ], ) @@ -87,7 +74,6 @@ cc_library( srcs = ["qr.cc"], hdrs = ["qr.h"], deps = [ - ":batch_dot", ":util", ":while_loop", "//tensorflow/compiler/xla:literal_util", @@ -99,7 +85,8 @@ cc_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) @@ -124,51 +111,6 @@ cc_library( ], ) -cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 - deps = [ - ":triangular_solve", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "util", srcs = ["util.cc"], @@ -187,29 +129,6 @@ cc_library( ], ) -xla_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "while_loop", srcs = ["while_loop.cc"], diff --git 
a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc deleted file mode 100644 index 5400e8834cb9807f6dd71abe7789b2672e29e905..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" - -#include -#include - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). 
- if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", - xla::ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); - } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? 
(ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::LiteralUtil::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - - return xla::DotGeneral(x, y, dot_dnums, &precision_proto); - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h deleted file mode 100644 index 6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace tensorflow { - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each -// can be elementwise-complex-conjugated by setting the `conjugate_x` or -// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both -// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. 
-// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index 3e402ef855cd7c114332d84032bc869232404fc8..be31f116686a2e302ece730e9d03312a45888a61 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -80,10 +80,8 @@ xla::StatusOr BroadcastTo(xla::XlaOp input, broadcast_dim = broadcast_shape_size - broadcast_dim - 1; } absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::BroadcastInDim( - input, - xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape), - broadcast_dims); + xla::XlaOp output = + xla::BroadcastInDim(input, broadcast_shape, broadcast_dims); if (broadcast_shape != output_dims) { output = xla::Reshape(output, output_dims); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index ab3d0a566839343828d176d9a46672824e425613..550ab5b05693b79e60e49577309328ac6846d3f9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -18,11 +18,12 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -101,10 +102,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -122,10 +120,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. 
np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -185,9 +180,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 6b3f2b6e065b5c99e2d0248237369ecc30188aa5..d6007748609fdd161cb89692a167eb7ed12fe00c 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -18,13 +18,13 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -191,12 +191,8 @@ xla::StatusOr QRBlock( auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = - BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - vva = - BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto vva = BatchDot(v_broadcast, a, precision); + vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -278,12 +274,9 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto yv = BatchDot(TransposeInMinorDims(y), v, precision); // wyv has shape [..., m, 1] - auto wyv = - BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto wyv = BatchDot(w, yv, precision); auto z = xla::Mul( -beta, v + wyv, @@ -375,23 +368,15 @@ xla::StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, 
i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = - BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a_update = - BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision); + a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = - BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - q_update = BatchDot(q_update, y, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto q_update = BatchDot(q_panel, w, precision); + q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 804671fbc75b0a5a6e04b204822b6f084013cd8b..c0bd172d17c192435ba8ee196f9def0491c0bf5c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - 
auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); - }); -} std::vector ConcatVectors(absl::Span xs, absl::Span ys) { @@ -152,100 +122,4 @@ std::vector ConcatVectors(absl::Span xs, return output; } -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); - auto padded_sizes = ConcatVectors(major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); - }); -} - -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
- std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); - }); -} - -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(x, update, padded_start); - }); -} - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return xla::DynamicUpdateSlice(x, update, padded_starts); -} - -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); - }); -} - -xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { - xla::XlaBuilder* builder = 
x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); - }); -} - -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; - }); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 80e9e5b002d49581209e608b98606e02709c5876..aec8061cb4322b8d315b6cdc80c7fff1e0cb4cb1 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -38,44 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last values being -// those in `starts`. -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts); - -// Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end); - // Returns the concatenation of `xs` and `ys`. std::vector ConcatVectors(absl::Span xs, absl::Span ys); -// Performs a dynamic slice in the minor dimensions of a Tensor. 
-xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes); - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0], ..., start[n]] = update -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Applies a complex conjugation operation if `a` is complex and `conjugate_a` -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index c9f486edc8d30954619db0967c988fe8e26938de..fef97b98c376d9df8bbfd9cb6651216895e46bf4 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,11 +1,13 @@ licenses(["notice"]) # Apache 2.0 +package_group( + name = "friends", + includes = ["//tensorflow:internal"], +) + package( default_visibility = [ - "//learning/deepmind/public/wavenet/python:__subpackages__", - "//learning/deepmind/research/alphastar:__subpackages__", - "//learning/tfx:__subpackages__", - "//tensorflow:internal", + ":friends", ], ) diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index f7e34a5b40c2f9244c029ed325a76322b8cf54dd..0b231ea8e7a2d8e303e91911e2e0a36fc83e78b4 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 425e769346ffcbc548495d93cb7adc779f860110..c7341cf8b9e8d7a06fd304ae8766420d20f0c16e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -26,7 +26,7 @@ limitations under the License. // Forward-declare, rather than include, to reduce code size for users that // never use this functionality. namespace xla { -class ProgramShape; +class ProgramShapeProto; class HloProfilePrinterData; } @@ -84,7 +84,7 @@ class XlaCompiledCpuFunction { void set_result_names(const char** result_names) { result_names_ = result_names; } - void set_program_shape(const xla::ProgramShape* program_shape) { + void set_program_shape(const xla::ProgramShapeProto* program_shape) { program_shape_ = program_shape; } const xla::HloProfilePrinterData* hlo_profile_printer_data() const { @@ -122,7 +122,7 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; @@ -206,8 +206,14 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. 
- void set_arg_data(size_t index, void* data) { - buffer_table_[arg_index_table_[index]] = data; + void set_arg_data(size_t index, const void* data) { + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. + buffer_table_[arg_index_table_[index]] = const_cast(data); } // ------------------------------ @@ -264,7 +270,7 @@ class XlaCompiledCpuFunction { // Returns the shape of the args and results. May return nullptr if the // program shape isn't available. - const xla::ProgramShape* ProgramShape() const { return program_shape_; } + const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } bool hlo_profiling_enabled() const { return hlo_profile_printer_data_ != nullptr; @@ -287,11 +293,6 @@ class XlaCompiledCpuFunction { // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. - // - // For now we need to keep around the args_ array because there is code that - // depends on args() returning a void**. However, in the future we may remove - // args_ in favor of using buffer_table_ as the sole storage for the - // arguments. const int32* const arg_index_table_; // The number of incoming arguments. @@ -310,7 +311,7 @@ class XlaCompiledCpuFunction { // Optional metadata. 
const char** arg_names_ = nullptr; const char** result_names_ = nullptr; - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index a08d030ce710bdb97910c01a64f80199fc10d649..ee461a3c07d4db514c7697e005a9371be4b54dd0 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -158,7 +158,8 @@ Status BuildComputation( xla::XlaBuilder* builder, xla::XlaComputation* computation, int* num_computation_outputs, int* num_nonconst_outputs, std::vector* outputs, - std::vector* resource_updates) { + std::vector* resource_updates, + xla::Shape* output_shape) { // Attach a common operator name as metadata. This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. @@ -176,6 +177,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); + + // Keeps track of which retvals have layout to update. The first element is + // the output index, second element is the new layout. + std::vector> retval_to_update_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -202,10 +207,12 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + retval_to_update_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. 
value = identity_op(value); } + elems.push_back(value); break; } @@ -297,6 +304,21 @@ Status BuildComputation( return computation_status.status(); } *computation = computation_status.ConsumeValueOrDie(); + + TF_ASSIGN_OR_RETURN(const auto& program_shape, + computation->GetProgramShape()); + *output_shape = program_shape.result(); + // Update the output layout to the layout of retval. + for (auto& update : retval_to_update_layout) { + if (!always_return_tuple && elems.size() == 1) { + *output_shape->mutable_layout() = update.second; + continue; + } + + xla::Shape* output_sub_shape = + xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); + *output_sub_shape->mutable_layout() = update.second; + } return Status::OK(); } @@ -304,10 +326,10 @@ Status BuildComputation( bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.initialized, other.tensor_array_size, + other.initialized, other.max_array_size, other.tensor_array_gradients)) { return false; } @@ -337,8 +359,8 @@ string XlaCompiler::Argument::HumanString() const { string output = absl::StrCat("kind=resource", common, " resource_kind=", XlaResource::KindToString(resource_kind), " initialized=", initialized); - if (tensor_array_size >= 0) { - absl::StrAppend(&output, " tensor_array_size=", tensor_array_size); + if (max_array_size >= 0) { + absl::StrAppend(&output, " max_array_size=", max_array_size); } if (!tensor_array_gradients.empty()) { absl::StrAppend(&output, " tensor_array_gradients=", @@ -358,7 +380,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) initialization_status_(Status::OK()), next_step_id_(1), device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), - device_mgr_({device_}) { + 
device_mgr_(absl::WrapUnique(device_)) { CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = @@ -545,12 +567,12 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return Status::OK(); } case XlaResource::kTensorArray: { - if (arg.tensor_array_size < 0) { + if (arg.max_array_size < 0) { return errors::InvalidArgument( - "Negative tensor_array_size in XLAShapeForArgument"); + "Negative max_array_size in XLAShapeForArgument"); } TensorShape shape; - shape.AddDim(arg.tensor_array_size); + shape.AddDim(arg.max_array_size); shape.AppendShape(arg.shape); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); @@ -562,12 +584,12 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return Status::OK(); } case XlaResource::kStack: { - if (arg.tensor_array_size < 0) { + if (arg.max_array_size < 0) { return errors::InvalidArgument( - "Negative tensor_array_size in XLAShapeForArgument"); + "Negative max_array_size in XLAShapeForArgument"); } TensorShape shape; - shape.AddDim(arg.tensor_array_size); + shape.AddDim(arg.max_array_size); shape.AppendShape(arg.shape); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( @@ -613,21 +635,23 @@ Status XlaCompiler::BuildArguments( const XlaCompiler::Argument& arg = args[i]; XlaExpression& arg_expression = (*arg_expressions)[i]; switch (arg.kind) { - case XlaCompiler::Argument::kResource: + case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); // TODO(phawkins): this code assumes that resource arguments do not // alias. 
- XlaResource* resource; - TF_RETURN_IF_ERROR(context->CreateResource( - arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), - /*tensor_array_size=*/arg.tensor_array_size, - /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); + XlaResource* resource = + context->AddResource(absl::make_unique( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::XlaOp(), + /*max_array_size=*/arg.max_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, + /*tensor_array_multiple_writes_aggregate=*/true)); arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } - break; + } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); @@ -901,9 +925,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - &options_.shape_representation_fn); + XlaContext* context = new XlaContext(this, &builder); core::ScopedUnref context_unref(context); std::vector real_args(args.begin(), args.end()); @@ -988,23 +1010,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, - &result->resource_updates)); + &result->resource_updates, &result->xla_output_shape)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - - // Compute the XLA output shape, if there is a computation with non-constant - // outputs. 
- TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, - client()->GetComputationShape(*result->computation)); - - result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " - << xla::ShapeUtil::HumanString(result->xla_output_shape); - - // Tensorflow expects a major-to-minor order of results. - xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - + << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 63426124686e1b92a3534b7e365b8282008b8455..0d801b73a8c2651305328384377751254ecaa41d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -150,7 +150,7 @@ class XlaCompiler { // For a TensorArray or Stack resource, what is the array's declared size? // (Used for lazy initialization.) - int64 tensor_array_size = -1; + int64 max_array_size = -1; // TensorArray resource parameters are passed as (array, gradient array 0, // ..., gradient array k), where the gradient arrays are in the same order diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index aaee208f6349d56f685481977cea55c8dd5e7938..fe2a5f5b0c9ea6b5f2bb71df836fdcabf9a0cf23 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -649,7 +650,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; // Compiles the graph. @@ -708,7 +709,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; // Compiles the graph. @@ -740,7 +741,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; // Compiles the graph. @@ -910,6 +911,82 @@ TEST_F(XlaCompilerTest, Variables) { RunAndCheckVariablesComputation(client_, result); } +TEST_F(XlaCompilerTest, ResultLayoutSingle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET"), a, 0); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. 
+ std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + auto compile_options = XlaCompiler::CompileOptions(); + compile_options.always_return_tuple = false; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph), + args, &result)); + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); +} + +TEST_F(XlaCompilerTest, ResultLayoutMultiple) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0); + auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. 
+ XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id", + std::move(graph), args, &result)); + xla::Shape result_shape = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({result_shape, result_shape}))); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { Scope scope = Scope::NewRootScope().ExitOnError(); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 43095fbb47351617a0de12a088c947106ccaa641..a69af70503376b6c0905deb8980abdc3254a6e47 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -54,25 +54,14 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context"; return *context; } -/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) { - return Get(ctx->op_kernel_context()); -} - void XlaContext::set_args(std::vector args) { args_ = std::move(args); } -XlaContext::XlaContext( - XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, - const std::function( - const TensorShape&, DataType)>* shape_representation_fn) - : compiler_(compiler), - builder_(builder), - allow_cpu_custom_calls_(allow_cpu_custom_calls), - shape_representation_fn_(shape_representation_fn) {} +XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) + : compiler_(compiler), builder_(builder) {} -string XlaContext::DebugString() { return "TLA JIT context"; } +string XlaContext::DebugString() { return "XLA JIT context"; } void XlaContext::SetRetval(int index, const XlaExpression& expression) { if (retvals_.size() <= index) { @@ -81,21 +70,9 @@ void XlaContext::SetRetval(int index, const XlaExpression& expression) { retvals_[index] = expression; } -Status 
XlaContext::CreateResource( - XlaResource::Kind kind, int arg_num, string name, DataType type, - TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, - const std::set& tensor_array_gradients, XlaResource** resource) { - resources_.emplace_back( - new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), - handle, tensor_array_size, tensor_array_gradients, - /*tensor_array_multiple_writes_aggregate=*/false)); - *resource = resources_.back().get(); - return Status::OK(); -} - -xla::StatusOr XlaContext::RepresentationShape( - const TensorShape& shape, DataType type) const { - return (*shape_representation_fn_)(shape, type); +XlaResource* XlaContext::AddResource(std::unique_ptr resource) { + resources_.push_back(std::move(resource)); + return resources_.back().get(); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index dbfd344c9bad8a5d05abb6a3b902ed3baebbe02a..0767d1faac14cedb8666f6cc37175eb7b55f6158 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -41,14 +41,10 @@ class XlaContext : public ResourceBase { public: // Retrieves the XlaContext of the current compilation. static XlaContext& Get(const OpKernelContext* ctx); - static XlaContext& Get(const XlaOpKernelContext* ctx); // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. - XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, - const std::function( - const TensorShape&, DataType)>* shape_representation_fn); + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. string DebugString() override; @@ -58,8 +54,6 @@ class XlaContext : public ResourceBase { // Returns the XlaBuilder that Ops use for compiling new expressions. 
xla::XlaBuilder* builder() { return builder_; } - bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - const std::vector& args() const { return args_; } void set_args(std::vector args); @@ -70,25 +64,13 @@ class XlaContext : public ResourceBase { // grows the return values vector to size index+1 if it is smaller. void SetRetval(int index, const XlaExpression& expression); - // Creates a resource with resource `kind` and initial value `handle`. `name` - // is a descriptive name for use in error messages. See the `XlaResource` - // constructor for a description of the remaining arguments. - // Fails if the resource already exists. - Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, TensorShape shape, - const xla::XlaOp& handle, int64 tensor_array_size, - const std::set& tensor_array_gradients, - XlaResource** resource); + // Adds 'resource' to the set of resources owned by the context. + XlaResource* AddResource(std::unique_ptr resource); const std::vector>& resources() { return resources_; } - // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`, or of an argument or return value of a top-level computation. - xla::StatusOr RepresentationShape(const TensorShape& shape, - DataType type) const; - // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. @@ -118,9 +100,6 @@ class XlaContext : public ResourceBase { // The XlaBuilder used to construct the subgraph's compiled representation. xla::XlaBuilder* builder_; - // Allow ops to emit CustomCall operations for CPU. - const bool allow_cpu_custom_calls_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. 
std::vector args_; @@ -131,11 +110,6 @@ class XlaContext : public ResourceBase { // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // Describes the on-host shapes of parameters and return values. Also see: - // XlaDevice::Options::shape_representation_fn. - const std::function(const TensorShape&, DataType)>* - shape_representation_fn_; - // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9a34cd8c6ae2dc6d52a3cc69168df96f5322c6da..c2c0751211180c3715a19d6c78e34659fd18914e 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" @@ -216,8 +215,7 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { return dtype; } -xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, - const xla::XlaOp& operand, +xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 39578144caaadf293d24ea91aa874e56e27ecc01..4858dfee55a393d04cd2af83916eeb40820ee368 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -80,8 +80,7 @@ class XlaHelpers { // A helper for creating a ConvertElementType xla op given a 
DataType rather // than the xla::PrimitiveType. - static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder, - const xla::XlaOp& operand, + static xla::XlaOp ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type); }; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 86a78ee429e8913edb4a948727fa692083c472f4..fabbcd04fed96ad814d04c2df9394f43bfe0cf99 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile( jit->executable_ = std::move(executable); jit->buffer_infos_ = std::move(buffer_infos); jit->arg_index_table_ = std::move(arg_index_table); - jit->program_shape_ = std::move(program_shape); + jit->program_shape_ = + absl::make_unique(program_shape->ToProto()); jit->static_data_.set_raw_function(raw_function); jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index d3c8f22a8078d03d15447ed200c914390f40b04f..a5392057177e983e11787c31bb496a8947add1e6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction { std::vector arg_names_; std::vector result_names_; - // The backing data for the program shape. - std::unique_ptr program_shape_; + // The backing data for the program shape. The proto form of program shape is + // used because the program shape is serialized and embedded in the object + // file. 
+ std::unique_ptr program_shape_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 6d49298a6f3e8a726695fafc42f3c5341fe98b5f..8846088678b53f6b9ecff0de732d6b5c82392b5a 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -116,13 +116,13 @@ TEST(XlaJitCompiledCpuFunction, Sum) { // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); - const xla::ProgramShape* program_shape = function.ProgramShape(); - ASSERT_TRUE(program_shape != nullptr); - ASSERT_EQ(program_shape->parameters_size(), 2); - EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); - EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32)); + ASSERT_TRUE(function.ProgramShape() != nullptr); + const xla::ProgramShape program_shape(*function.ProgramShape()); + ASSERT_EQ(program_shape.parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32)); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32)); - const xla::Shape& result = program_shape->result(); + const xla::Shape& result = program_shape.result(); ASSERT_EQ(result.element_type(), xla::TUPLE); ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1); const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 8dd8def0549f2b39d4c9863bb535f19703c3ef22..58808c76de6330a6b28e21dbdead03dea25847f6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -36,8 +36,16 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); } +XlaContext* XlaOpKernelContext::xla_context() const { + 
return &XlaContext::Get(context_); +} + xla::XlaBuilder* XlaOpKernelContext::builder() const { - return XlaContext::Get(this).builder(); + return xla_context()->builder(); +} + +XlaCompiler* XlaOpKernelContext::compiler() const { + return xla_context()->compiler(); } // Retrieves an XlaExpression that was allocated by a previous Op. @@ -338,8 +346,8 @@ Status XlaOpKernelContext::ConstantInputList( namespace { Status ReadVariableInputTensor(const Tensor& tensor, DataType type, - const OpKernelContext* ctx, TensorShape* shape, - xla::XlaOp* value) { + const XlaOpKernelContext* ctx, + TensorShape* shape, xla::XlaOp* value) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); @@ -357,10 +365,9 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, *shape = variable->shape(); } - XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN( - xla::Shape representation_shape, - xla_context.RepresentationShape(variable->shape(), variable->type())); + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn( + variable->shape(), variable->type())); xla::Shape xla_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); @@ -377,15 +384,15 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, TensorShape* shape, xla::XlaOp* value) { - return ReadVariableInputTensor(context_->input(index), type, context_, shape, + return ReadVariableInputTensor(context_->input(index), type, this, shape, value); } Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, DataType type, TensorShape* shape, xla::XlaOp* value) { - return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, - shape, value); + return ReadVariableInputTensor(GetInputTensorByName(name), type, 
this, shape, + value); } Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, @@ -464,7 +471,7 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { namespace { Status AssignVariableTensor(const Tensor& tensor, DataType type, - const OpKernelContext* ctx, xla::XlaOp handle, + const XlaOpKernelContext* ctx, xla::XlaOp handle, xla::XlaBuilder* builder) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -481,9 +488,9 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); - XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, - xla_context.RepresentationShape(shape, type)); + TF_ASSIGN_OR_RETURN( + xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn(shape, type)); xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { @@ -498,19 +505,15 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); - return AssignVariableTensor(context_->input(input_index), type, context_, - handle, builder()); + return AssignVariableTensor(context_->input(input_index), type, this, handle, + builder()); } Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); - return AssignVariableTensor(GetInputTensorByName(name), type, context_, - handle, builder()); -} - -XlaCompiler* XlaOpKernelContext::compiler() const { - return XlaContext::Get(context_).compiler(); + return AssignVariableTensor(GetInputTensorByName(name), type, this, handle, + builder()); } void XlaOpKernelContext::CtxFailure(const Status& s) { @@ 
-530,22 +533,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMax(type); + return xla_context()->GetOrCreateMax(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMin(type); + return xla_context()->GetOrCreateMin(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd( const DataType type) { - return XlaContext::Get(context_).GetOrCreateAdd(type); + return xla_context()->GetOrCreateAdd(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMul(type); + return xla_context()->GetOrCreateMul(type); } const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index c06efa2c474c5ec3cb5d75d94ba15d4096faa085..1858844bc05a6e12abbf07af83cad816590ddd03 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -60,6 +60,8 @@ class XlaOpKernelContext { public: explicit XlaOpKernelContext(OpKernelContext* context); + XlaContext* xla_context() const; + // Returns the XLA XlaBuilder containing the output of compilation. xla::XlaBuilder* builder() const; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index dcd0e9c5c1f20c07c6d2b6fd7315a861817bc523..14237df69081016817fbd1a5332f22996e7f264d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -130,8 +130,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. static void* registration_init = [®istry]() { - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool cpu_global_jit = flags->tf_xla_cpu_global_jit; mutex_lock lock(registry.mutex_); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index a322eb9015e829fd468133f3de6c12aad7e4ff74..48a3c012727acd8472d3d5d4072ae700f5497d96 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -39,9 +40,29 @@ namespace tensorflow { } } +/*static*/ std::unique_ptr XlaResource::CreateStack( + string name, DataType type, int64 max_size) { + return absl::make_unique( + XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(), + /*initial_value=*/xla::XlaOp(), + /*max_array_size=*/max_size, + /*tensor_array_gradients=*/std::set{}, + /*tensor_array_multiple_writes_aggregate=*/false); +} + +/*static*/ std::unique_ptr XlaResource::CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64 max_array_size) { + return absl::make_unique( + XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape, + initial_value, max_array_size, + /*tensor_array_gradients=*/std::set{}, + /*tensor_array_multiple_writes_aggregate=*/false); +} + XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, - int64 tensor_array_size, + int64 max_array_size, const std::set& tensor_array_gradients, bool tensor_array_multiple_writes_aggregate) : kind_(kind), @@ -51,7 +72,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, shape_(std::move(shape)), value_(initial_value), initial_value_(initial_value), - tensor_array_size_(tensor_array_size), + max_array_size_(max_array_size), tensor_array_multiple_writes_aggregate_( tensor_array_multiple_writes_aggregate) { CHECK(kind_ != kInvalid); @@ -60,7 +81,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, - xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}, + 
xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{}, /*tensor_array_multiple_writes_aggregate=*/true)); } } @@ -113,7 +134,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kTensorArray: { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); @@ -121,7 +142,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kStack: { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); value_ = xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), @@ -146,14 +167,14 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); xla::XlaOp gradient_value = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), - type_, shape_, gradient_value, tensor_array_size_, + type_, shape_, gradient_value, max_array_size_, /*tensor_array_gradients=*/{}, /*tensor_array_multiple_writes_aggregate=*/true)); } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 857b9a928bb824656f637b2b1ca2fc02a1bef139..736588bb8b89ba756cdce77eeebff8d1fcf4774c 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -38,9 +38,18 @@ class XlaResource { }; static absl::string_view KindToString(Kind kind); + // Creates a new Stack resource. + static std::unique_ptr CreateStack(string name, DataType type, + int64 max_size); + + // Creates a new TensorArray resource. 
+ static std::unique_ptr CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64 max_array_size); + XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, - int64 tensor_array_size, + int64 max_array_size, const std::set& tensor_array_gradients, bool tensor_array_multiple_writes_aggregate); @@ -119,12 +128,12 @@ class XlaResource { // TODO(phawkins): refactor this code to use subclasses, rather than putting // kind-specific fields in XlaResource. - // 'tensor_array_size' stores the expected size of the TensorArray or Stack. + // 'max_array_size' stores the expected size of the TensorArray or Stack. // We need to store this since sometimes TensorArrays must be initialized // lazily since we do not know the element shape at construction time. // Used by both TensorArrays and Stacks. - int64 tensor_array_size() const { return tensor_array_size_; } - void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + int64 max_array_size() const { return max_array_size_; } + void set_max_array_size(int64 size) { max_array_size_ = size; } bool tensor_array_multiple_writes_aggregate() const { return tensor_array_multiple_writes_aggregate_; @@ -151,7 +160,7 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; - int64 tensor_array_size_ = -1; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; std::map> tensor_array_gradients_; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 91096cf1d043eb652756f77b7594780124260766..4360e0857964b0ac63fc887e269b04a4b00d854a 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -226,12 +226,14 @@ cc_library( "index_util.cc", "layout_util.cc", "primitive_util.cc", + "shape.cc", "shape_util.cc", ], hdrs = [ "index_util.h", "layout_util.h", "primitive_util.h", + "shape.h", "shape_util.h", ], visibility = ["//visibility:public"], 
@@ -254,6 +256,23 @@ cc_library( ], ) +tf_cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "shape_util_test", srcs = ["shape_util_test.cc"], @@ -745,6 +764,8 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 782c966b4c57672d137569a318fb20ace14d493b..e4aca98f67d50287a83afc6f41a59458f3df2da2 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -104,7 +104,7 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? 
(to - from) / (count - 1) : 0); - auto set = [&array, n1, n2](int64 index, NativeT value) { + auto set = [&array, n2](int64 index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64 i = 0; i < count - 1; ++i) { diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 42da0ebf4992884187bbe21701a44d8ba2fccd64..fe99564d3c671cd7890e1fa26fcd2e3384972983 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -81,6 +81,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -90,11 +91,12 @@ cc_library( srcs = ["executable_build_options.cc"], hdrs = ["executable_build_options.h"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -191,6 +193,7 @@ cc_library( hdrs = ["xla_computation.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index eef2844e0df6aaf509881535f41493673fbeeee5..74b76f929949d3300a5d0ff45d5fa4cd9f162642 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -42,7 +43,7 @@ StatusOr Client::Transfer(const GlobalData& data, TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = *shape_with_layout; + *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); } TransferToClientResponse response; @@ -123,7 +124,7 @@ StatusOr Client::TransferFromOutfeed( } request.set_replica_id(replica_id); if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = *shape_with_layout; + *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); } TransferFromOutfeedResponse response; @@ -170,11 +171,14 @@ StatusOr Client::ExecuteAndTransfer( std::unique_ptr data, Execute(computation, arguments, execution_options, execution_profile)); - const Shape* shape_with_output_layout = nullptr; + absl::optional shape_with_output_layout; if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); + shape_with_output_layout = + Shape(execution_options->shape_with_output_layout()); } - return Transfer(*data, shape_with_output_layout); + return Transfer(*data, shape_with_output_layout.has_value() + ? &(*shape_with_output_layout) + : nullptr); } StatusOr Client::ComputeConstant(const XlaComputation& computation, @@ -229,7 +233,7 @@ StatusOr Client::Compile( // The argument shapes affect how the computation is compiled. 
for (const auto& arg_shape : argument_shapes) { - *request.add_input_shape_with_layout() = arg_shape; + *request.add_input_shape_with_layout() = arg_shape.ToProto(); } CompileResponse response; @@ -458,7 +462,7 @@ StatusOr Client::GetShape(const GlobalData& data) { return s; } - return response.shape(); + return Shape(response.shape()); } StatusOr Client::ExecutionStatsAsString( diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 0f1745366b7c33e573aff2e66d85431b01488c49..1f594e551af381d7537e947892cbf7e0b5b3b861 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" namespace xla { @@ -39,6 +40,13 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } +DebugOptions* ExecutableBuildOptions::mutable_debug_options() { + if (!has_debug_options()) { + debug_options_ = GetDebugOptionsFromFlags(); + } + return &debug_options_.value(); +} + ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( const Shape& shape_with_layout) { result_layout_set_ = true; @@ -55,68 +63,10 @@ string ExecutableBuildOptions::ToString() const { if (result_layout_set_) { result_layout = ShapeUtil::HumanStringWithLayout(result_layout_); } - string generate_hlo_graph = "nullopt"; - if (generate_hlo_graph_.has_value()) { - generate_hlo_graph = generate_hlo_graph_.value(); - } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout, generate_hlo_graph); -} - -ExecutableBuildOptions& 
ExecutableBuildOptions::set_generate_hlo_graph( - string regex) { - generate_hlo_graph_ = std::move(regex); - return *this; -} - -const absl::optional& ExecutableBuildOptions::generate_hlo_graph() - const { - return generate_hlo_graph_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - absl::string_view dirpath) { - dump_optimized_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { - return dump_optimized_hlo_proto_to_; -} - -ExecutableBuildOptions& -ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - absl::string_view dirpath) { - dump_unoptimized_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { - return dump_unoptimized_hlo_proto_to_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - absl::string_view dirpath) { - dump_per_pass_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { - return dump_per_pass_hlo_proto_to_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { - hlo_profile_ = enabled; - return *this; -} - -absl::optional ExecutableBuildOptions::hlo_profile() const { - return hlo_profile_; + device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 93334db88bc24f2ffbf3c7a57ee45ef238286739..a58090253bfac7779e4b61bc7231a0f0d945cc00 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -19,7 +19,9 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -44,6 +46,12 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; + // Expose access to the XLA debug options which will be passed to the + // compilation process. + bool has_debug_options() const { return debug_options_.has_value(); } + const DebugOptions& debug_options() const { return *debug_options_; } + DebugOptions* mutable_debug_options(); + // If set, this specifies an allocator that can be used to allocate temporary // space on the device during compilation. For example, the compiler might // want to run various algorithms on the device and pick the fastest one -- it @@ -55,56 +63,16 @@ class ExecutableBuildOptions { DeviceMemoryAllocator* allocator); DeviceMemoryAllocator* device_allocator() const; - // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). - ExecutableBuildOptions& set_generate_hlo_graph(string regex); - const absl::optional& generate_hlo_graph() const; - - // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO - // protobuf to (as in DebugOptions). - ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_optimized_hlo_proto_to() const; - - // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO - // protobuf to (as in DebugOptions). - ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_unoptimized_hlo_proto_to() const; - - // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs - // to (as in DebugOptions). 
- ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_per_pass_hlo_proto_to() const; - - // If true, specifies that we should record an HLO profile during execution - // and log it after execution (as in DebugOptions). If nullopt the default is - // used. - ExecutableBuildOptions& set_hlo_profile(bool enabled); - absl::optional hlo_profile() const; - - void add_disabled_hlo_pass(absl::string_view pass_name) { - disabled_hlo_passes_.push_back(std::string(pass_name)); - } - const absl::Span disabled_hlo_passes() const { - return disabled_hlo_passes_; - } - // Returns a string representation of the build options, suitable for // debugging. string ToString() const; private: - absl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - absl::optional generate_hlo_graph_; - absl::optional dump_optimized_hlo_proto_to_; - absl::optional dump_unoptimized_hlo_proto_to_; - absl::optional dump_per_pass_hlo_proto_to_; + absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; - std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index f833ddcd3235e08e2d0d3c0b9921e96ef871c89e..41db8de29ff0085a30847ff41db4ffbfc774e2a1 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -104,13 +104,17 @@ xla_test( ) cc_library( - name = "numeric", - srcs = ["numeric.cc"], - hdrs = ["numeric.h"], + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], deps = [ ":arithmetic", ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", 
"@com_google_absl//absl/types:span", @@ -118,11 +122,12 @@ cc_library( ) xla_test( - name = "numeric_test", - srcs = ["numeric_test.cc"], + name = "matrix_test", + srcs = ["matrix_test.cc"], tags = ["enable_for_xla_interpreter"], deps = [ - ":numeric", + ":matrix", + ":slicing", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -164,7 +169,6 @@ cc_library( deps = [ ":constants", ":math", - ":numeric", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", @@ -173,13 +177,46 @@ cc_library( ], ) +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "sorting", srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ - ":numeric", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", ], @@ -188,10 +225,6 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - 
], tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", @@ -225,3 +258,48 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 08a887a6e4660cb2528f0ec7244b7ccc540808d2..36fdda39b4124b9100c6054160f9c17bdf787d6f 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -268,17 
+268,16 @@ XlaOp Digamma(XlaOp input) { // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. XlaOp RoundToEven(XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); - auto round_val = xla::Floor(x); + auto round_val = Floor(x); auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); } // Trigonometric functions. @@ -320,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } +XlaOp MaybeConjugate(XlaOp x, bool conjugate) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == C64 && conjugate; + return perform_conj ? Conj(x) : x; + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 3f06d04b9ae98b3aa75e68cd07810b2b4c24d280..17612bf9fdc0f1eabb338671c93c025c5b268872 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -86,6 +86,10 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); +// Applies a complex conjugation operation if `a` is complex and `conjugate` +// is true, otherwise returns its argument. 
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffd744d190885b8e3f4149a48a706498b3787618 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, type, m); + auto b = Iota(builder, type, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? 
n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + XlaOp indicator; + if (lower) { + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } else { + indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } + auto mask = Broadcast(indicator, major_dims); + + return Select(mask, x, Zeros(builder, shape)); + }); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { + return InvalidArgument( + "Arguments to BatchDot have different ranks: %s vs. %s", + ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + } + const int ndims = ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return InvalidArgument( + "Arguments to BatchDot must have rank >= 2: got %d", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. 
+ std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return InvalidArgument( + "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", + i, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = ndims - 1; + int y_inner_dim = ndims - 2; + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return InvalidArgument( + "Dimensions %d and %d of arguments to BatchDot must be equal: " + "shapes %s vs %s", + x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + + // Check for zero lhs/rhs dim size. + if (ShapeUtil::IsZeroElementArray(x_shape) || + ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = ndims - 2; + int y_outer_dim = ndims - 1; + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return Broadcast( + ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + + return DotGeneral(x, y, dot_dnums, &precision_proto); + }); +} + +XlaOp TransposeInMinorDims(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { 
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return Transpose(x, permutation); + }); +} + +XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { + return transpose ? TransposeInMinorDims(x) : x; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/matrix.h similarity index 56% rename from tensorflow/compiler/xla/client/lib/numeric.h rename to tensorflow/compiler/xla/client/lib/matrix.h index efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1..8856f99c7a0fee8f315aac11fab392cf5536f57b 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" @@ -22,9 +22,6 @@ limitations under the License. namespace xla { -// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); - // Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); @@ -43,6 +40,34 @@ XlaOp UpperTriangle(XlaOp x); // Get the lower triangle part of the last two dimensions XlaOp LowerTriangle(XlaOp x); +// Multiplies slices of two tensors in batches. 
+ +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc similarity index 53% rename from tensorflow/compiler/xla/client/lib/numeric_test.cc rename to tensorflow/compiler/xla/client/lib/matrix_test.cc index 7d6aedd49462bd4f075f90d0b0f85c40f1191aa1..0593a7517ac125ca8dc5395cee76f6bc23232cd3 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -24,13 +26,13 @@ limitations under the License. namespace xla { namespace { -class NumericTest : public ClientLibraryTestBase { +class MatrixTest : public ClientLibraryTestBase { protected: template void TestMatrixDiagonal(); }; -XLA_TEST_F(NumericTest, Triangle) { +XLA_TEST_F(MatrixTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); input.FillIota(0); @@ -45,7 +47,7 @@ XLA_TEST_F(NumericTest, Triangle) { } template -void NumericTest::TestMatrixDiagonal() { +void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); @@ -58,11 +60,46 @@ void NumericTest::TestMatrixDiagonal() { ComputeAndCompareR2(&builder, expected, {a_data.get()}); } -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(MatrixTest, RowBatchDot) { + XlaBuilder builder(TestName()); + + int n = 4; -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + 
"row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); -XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, TransposeInMinorDims(row)); + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc deleted file mode 100644 index 377654220b5df4487e9e194361473d54ff46a54e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" - -namespace xla { - -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, - int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); - auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); - return ConvertElementType(indicator, type); -} - -XlaOp GetMatrixDiagonal(XlaOp x) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); - - // TPUs don't support S64 add reduction at the moment. But fortunately - // OR-reductions work just as well for integers. - XlaComputation reducer = - primitive_util::IsIntegralType(shape.element_type()) - ? CreateScalarOrComputation(shape.element_type(), builder) - : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? 
n_dims - 2 : n_dims - 1}); - }); -} - -XlaOp Triangle(XlaOp x, bool lower) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - xla::XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); - }); -} - -XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } - -XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index c6f68c8ee2f5198017c37abeb9551478f52a99f4..85b9e1827dcef5ed907d893277deb5a52f8f30e9 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8c7df3ff5189c817202eaf39adb572f7e232ec2 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" + +namespace xla { + +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return Slice(x, padded_start, padded_end, strides); + }); +} + +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
+ std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return DynamicUpdateSlice(x, update, start_constant); + }); +} + +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); +} + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + +XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + auto zero = Reshape(ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); + } + return ConcatInDim(builder, padded_starts, 0); + }); +} + +} // namespace + +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes) { + XlaBuilder* builder = x.builder(); + return 
builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return DynamicSlice(x, padded_starts, padded_sizes); + }); +} + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return DynamicUpdateSlice(x, update, padded_starts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h new file mode 100644 index 0000000000000000000000000000000000000000..6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc similarity index 67% rename from tensorflow/compiler/tf2xla/lib/util_test.cc rename to tensorflow/compiler/xla/client/lib/slicing_test.cc index 442fe92c34ca26cb1a854cc90da8dc034bca79bb..8d362119e01006555db0f82d02626175936e1d05 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -13,28 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { -using UtilTest = xla::ClientLibraryTestBase; -using UtilLeftLookingTest = xla::ClientLibraryTestBase; +using SlicingTest = xla::ClientLibraryTestBase; xla::Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; @@ -63,7 +54,7 @@ xla::Array3D BatchedAValsFull() { }}; } -XLA_TEST_F(UtilTest, Simple2dLookup) { +XLA_TEST_F(SlicingTest, Simple2dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, x, y; @@ -77,7 +68,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(UtilTest, Simple3dLookup) { +XLA_TEST_F(SlicingTest, Simple3dLookup) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, index; @@ -92,7 +83,7 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { {a_data.get(), index_data.get()}); } -XLA_TEST_F(UtilTest, SimpleSliceUpdate) { +XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { xla::XlaBuilder builder(TestName()); xla::XlaOp a, b, x, y; @@ -111,26 +102,5 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } -XLA_TEST_F(UtilTest, RowBatchDot) { - xla::XlaBuilder 
builder(TestName()); - - int n = 4; - - xla::XlaOp a, row, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, - "row", &builder, &row); - // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). - auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - - auto l_index = DynamicSliceInMinorDims( - a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); - BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); - - ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); -} - } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 0475fd9c94f6e390b5169cfe2cbba8eae28ddc18..e8553a08bb014e790822a14e128686b60b8d6b7c 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,7 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -23,13 +25,12 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; - int last_dim_size = input_shape.dimensions(last_dim); - XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + Shape iota_shape = + ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); + XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); - XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); - XlaOp sort_result = Sort(Neg(input), {broadcast_s32}); + XlaOp sort_result = Sort(Neg(input), {iota_s32}); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index fef98c9923096e21a755c6d730de2c7c10852b2d..27ff36c7491ab8397d46f3a49493ff2b904deb2d 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" + +#include + #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -41,6 +44,28 @@ XLA_TEST_F(SortingTest, TopK3From8Indices) { ComputeAndCompareR1(&builder, {0, 1, 2}, {}); } +// TODO(b/119930279): enable this test. +XLA_TEST_F(SortingTest, DISABLED_TopKFullSortMinInt) { + XlaBuilder builder(TestName()); + auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::min() + 1, + std::numeric_limits::max()}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {2, 1, 0}, {}); +} + +XLA_TEST_F(SortingTest, NOT_TopKFullSortMinInt) { + XlaBuilder builder(TestName()); + auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::min() + 1, + std::numeric_limits::max()}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + // TopK currently negates the keys, which doesn't work correctly for + // std::numeric_limits::min(). Therefore, it will sort this key to the + // front instead of to the back. 
+ ComputeAndCompareR1(&builder, {0, 2, 1}, {}); +} + XLA_TEST_F(SortingTest, TopKFullSort) { XlaBuilder builder(TestName()); const int kSize = 16; @@ -56,5 +81,13 @@ XLA_TEST_F(SortingTest, TopKFullSort) { ComputeAndCompareR1(&builder, inputs, {}); } +XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) { + XlaBuilder builder(TestName()); + XlaOp a; + auto a_data = CreateR1Parameter({1, 1, 2, 2, 1}, 0, "a", &builder, &a); + xla::GetTupleElement(xla::TopK(a, 5), 1); + ComputeAndCompareR1(&builder, {2, 3, 0, 1, 4}, {a_data.get()}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index a44681f586278bf03f3fb2b8c812936cbf3ad47b..a95bbf2c8c860914877d3195b97342097dafc725 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -66,7 +66,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); - *execution_options.mutable_shape_with_output_layout() = shape; + *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } @@ -98,8 +98,8 @@ std::vector> MakeFakeArgumentsOrDie( auto program_shape = computation.proto().host_program_shape(); std::vector> results; - for (const Shape& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(shape, client)); + for (const ShapeProto& shape : program_shape.parameters()) { + results.push_back(MakeFakeDataOrDie(Shape(shape), client)); } return results; } diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc similarity index 62% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve.cc index 
6524c2a9b1ada632d80edd234272760c2b545cc4..c5a1d34cc66e6f8c1a832f8a8437163b846a5431 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -29,21 +29,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/math/math_util.h" -namespace tensorflow { +namespace xla { // Get the diagonal blocks of the coefficient matrix -xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); - int ndims = xla::ShapeUtil::Rank(shape); - int64 n = xla::ShapeUtil::GetDimension(shape, -1); +XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); + int ndims = ShapeUtil::Rank(shape); + int64 n = ShapeUtil::GetDimension(shape, -1); int64 num_blocks = n / block_size; - xla::XlaOp diag_blocks; + XlaOp diag_blocks; // If the coefficient matrix is exactly the block size, we just add a // singleton dimension i.e. 
[..., n, n] -> [..., 1, n, n] @@ -58,13 +57,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { if (n > block_size) { // Construct the starting indices of the diagonal blocks auto start_indices = - Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), - xla::ConstantR0(builder, block_size)), + Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks), + ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), /*permutation=*/{1, 0}); // Gather the diagonal blocks - xla::GatherDimensionNumbers dim_numbers; + GatherDimensionNumbers dim_numbers; dim_numbers.add_offset_dims(ndims - 1); dim_numbers.add_offset_dims(ndims); dim_numbers.add_start_index_map(ndims - 2); @@ -80,7 +79,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Pad with zeros auto last_blocks = SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); - xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + PaddingConfig config = MakeNoPaddingConfig(ndims); int64 padding = block_size - n % block_size; config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); @@ -89,9 +88,8 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Add a singleton dimension // i.e. 
[..., block_size, block_size] -> [..., 1, block_size, block_size] - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(last_blocks)); - auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); + auto shape_dims = AsInt64Slice(blocks_shape.dimensions()); auto last_blocks_dims = std::vector(ndims); std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); @@ -100,7 +98,7 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Concatenate with the other blocks if necessary if (n > block_size) { diag_blocks = - xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); } else { diag_blocks = last_blocks; } @@ -110,16 +108,16 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { +XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, + bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is // (..., size, size). We resize this to (num_blocks, size, size). 
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / tensorflow::MathUtil::IPow(block_size, 2); diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); @@ -131,9 +129,9 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, // zero (which can happen if the last block was padded) otherwise it will // introduce nans which will propagate auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); + TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); @@ -159,40 +157,40 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto start_index = (lower) ? 0 : block_size - 1; auto output_block = DynamicUpdateSlice( neg_identity, pos_one, - /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); + /*start_indices=*/ConstantR1(builder, 2, start_index)); // Broadcast diag([1, -1, -1, ...]) to every block - xla::XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); // Now we construct a loop that performs matrix-vector multiplications // inverting the blocks one row at a time - std::vector tuple_shapes = { + std::vector tuple_shapes = { // The loop iteration counter is a scalar, incremented each iteration. 
- xla::ShapeUtil::MakeShape(xla::S32, {}), + ShapeUtil::MakeShape(S32, {}), // The output has the shape of A, with one row updated each iteration. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), // The input is a loop invariant. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = One(builder, xla::S32); - auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); // Construct the loop condition function. - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("InvertDiagCond"); { auto i = GetTupleElement( Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, xla::ConstantR0(condb.get(), block_size)); + Lt(i, ConstantR0(condb.get(), block_size)); } TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); // Construct the loop body function. - std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("InvertDiagBody"); { auto input_tuple = @@ -202,21 +200,21 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, auto body_out = GetTupleElement(input_tuple, 1); auto body_input = GetTupleElement(input_tuple, 2); - auto zero = xla::ConstantR1(bodyb.get(), 1, 0); + auto zero = ConstantR1(bodyb.get(), 1, 0); auto j = (lower) ? 
i : ScalarLike(i, block_size - 1) - i; auto start_indices = - xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); auto input_row = DynamicSlice(body_input, start_indices, /*slice_sizes=*/{num_blocks, 1, block_size}); // We want -L21 L11^{-1} - xla::DotDimensionNumbers dnums; + DotDimensionNumbers dnums; dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfig precision_proto; + PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -224,7 +222,7 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, body_out = DynamicUpdateSlice(body_out, update, start_indices); auto next_i = i + ScalarLike(i, 1); - xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + Tuple(bodyb.get(), {next_i, body_out, body_input}); } TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); @@ -238,27 +236,26 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, /*broadcast_dimensions=*/{0, 1}); // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks( - xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(inv_diag_blocks)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); - - 
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - int64 ndims = xla::ShapeUtil::Rank(a_shape); - int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); +XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + int64 ndims = ShapeUtil::Rank(a_shape); + int64 n = ShapeUtil::GetDimension(a_shape, -1); int64 num_blocks = n / block_size + (n % block_size != 0); int64 m_dim = (left_side) ? -1 : -2; - int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + int64 m = ShapeUtil::GetDimension(b_shape, m_dim); // Initialize the solution auto x = ZerosLike(b); @@ -294,7 +291,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( } auto b_row = SliceInMinorDims(b, start, end); - xla::XlaOp remainder; + XlaOp remainder; if (i == 0) { remainder = b_row; } else { @@ -311,29 +308,27 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, + precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + remainder = + b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), + precision); } } - xla::XlaOp x_update; - auto zero = Zero(builder, xla::S32); - auto start_index = - xla::ConstantR0WithType(builder, xla::S32, 
j * block_size); - std::vector update_starts = {start_index, zero}; + XlaOp x_update; + auto zero = Zero(builder, S32); + auto start_index = ConstantR0WithType(builder, S32, j * block_size); + std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = - BatchDot(inv_block, remainder, transpose_a, false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), + remainder, precision); } else { - x_update = - BatchDot(remainder, inv_block, false, transpose_a, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + x_update = BatchDot(remainder, + MaybeTransposeInMinorDims(inv_block, transpose_a), + precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -343,24 +338,24 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( }); } -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. 
", - xla::ShapeUtil::HumanString(b_shape)); +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + return InvalidArgument( + "Arguments to TriangularSolve have shapes with different ranks: " + "%s vs. %s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } - const int64 ndims = xla::ShapeUtil::Rank(a_shape); + const int64 ndims = ShapeUtil::Rank(a_shape); if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); + return InvalidArgument( + "Arguments to TriangularSolve was rank %d but must have rank >= 2.", + ndims); } // The batch dimensions must be equal. 
std::vector batch_dimensions; @@ -368,32 +363,33 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, int64 a_size = a_shape.dimensions(i); int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + return InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal; " + "shapes were %s and %s.", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); + if (ShapeUtil::GetDimension(a_shape, -1) != + ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must be a batched square matrix;" + " shape was: %s", + ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); + const int64 m = ShapeUtil::GetDimension(b_shape, -2); + const int64 n = ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? 
m : n) != ShapeUtil::GetDimension(a_shape, -1)) { + return InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes %s and " + "%s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", + return InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got %d", block_size); } @@ -413,4 +409,4 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, }); } -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h similarity index 88% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.h rename to tensorflow/compiler/xla/client/lib/triangular_solve.h index 2303234f361e54cd2a0ad495cb03b371bed76877..50a3b30ebd1c15eb6d2ace4e351cb41f21db7093 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Solves systems of linear equations with lower or upper triangular coefficient // matrices by forward- or back-substitution. Broadcasting along leading @@ -57,11 +57,11 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. 
-xla::XlaOp TriangularSolve( - xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, +XlaOp TriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc similarity index 99% rename from tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index aeebf16028d40189203cdfd815f06a339ee72902..f6a70d64a788d95a456774ccbbcf67f2e5cac98b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include @@ -30,7 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { using TriangularSolveTest = xla::ClientLibraryTestBase; @@ -330,4 +330,4 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { } } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index f96b6c9c261a9686fb647e3da0dcc933cd1f70df..049cd15738a619294b19d5cf74ca514d7b4a00ad 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( + ShapeUtil::HumanStringWithLayout( computation_layout.parameter_layout(i).shape()), - ShapeUtil::HumanString(arguments[i]->on_host_shape())); + ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape())); } } @@ -310,4 +310,28 @@ StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } +StatusOr LocalClient::TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal) { + const ::xla::Shape& shape = literal.shape(); + + TF_ASSIGN_OR_RETURN( + ::xla::ScopedShapedBuffer shaped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + shape, backend().memory_allocator(), device_oridinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_oridinal)); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.get(), literal, shaped_buffer)); + std::vector<::xla::ScopedShapedBuffer> replicated_buffer; + replicated_buffer.emplace_back(std::move(shaped_buffer)); + ::xla::TransferToServerResponse result; + TF_ASSIGN_OR_RETURN(*result.mutable_data(), + local_service_->RegisterReplicatedBuffers( + 
std::move(replicated_buffer), + absl::StrCat("TransferToServer literal of shape ", + ::xla::ShapeUtil::HumanString(shape)))); + + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index e49451ca9708ab506d11af5f9855db245674864c..ddb36680e8b185b053368baffa6f1d5cac50dc07 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -129,6 +129,10 @@ class LocalClient : public Client { const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator = nullptr); + // Transfer the BorrowingLiteral to the device with the given ordinal. + StatusOr TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal); + // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index 176802b33ef824a1f898255a19e44def3c1fc982..fb9ea6ec3fc41d5e04ca125798a8199350470a44 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -36,7 +36,7 @@ OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *result.mutable_tile_shape() = tile_shape; + *result.mutable_tile_shape() = tile_shape.ToProto(); for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -52,7 +52,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); std::vector dimensions(1, num_tiles); - *result.mutable_tile_shape() = tile_shape; + *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = (*result.mutable_tile_shape()->mutable_dimensions())[0]; 
tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 0a587725d20507555382ef0657bdc08369a7fbac..60df2ec3959216b0564846ad47c21c5bcc01ea57 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -102,7 +102,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); - return instr->shape(); + return Shape(instr->shape()); } StatusOr> XlaBuilder::GetOperandShapes( @@ -155,7 +155,7 @@ StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { ProgramShape program_shape; - *program_shape.mutable_result() = root_proto->shape(); + *program_shape.mutable_result() = Shape(root_proto->shape()); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -172,7 +172,7 @@ StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { const int64 index = instr.parameter_number(); TF_RET_CHECK(index >= 0 && index < param_count) << "invalid parameter number: " << index; - *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameters(index) = Shape(instr.shape()); *program_shape.mutable_parameter_names(index) = instr.name(); } } @@ -239,6 +239,19 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, visited->insert(op_handle); } +Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, + int64 target_dim_num) { + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( + DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, + dynamic_size_param_index}, + DynamicParameterBinding::DynamicDimension{ + target_param_num, target_param_index, target_dim_num})); + return Status::OK(); +} + XlaComputation 
XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); @@ -275,7 +288,8 @@ StatusOr XlaBuilder::Build(int64 root_id) { HloComputationProto entry; SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId()); - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id)); + *entry.mutable_program_shape() = program_shape.ToProto(); entry.set_root_id(root_id); for (auto& instruction : instructions_) { @@ -297,6 +311,9 @@ StatusOr XlaBuilder::Build(int64 root_id) { } module->add_computations()->Swap(&entry); + *(module->mutable_dynamic_parameter_binding()) = + dynamic_parameter_binding_.ToProto(); + // Clear data held by this builder. this->instructions_.clear(); this->handle_to_index_.clear(); @@ -312,7 +329,7 @@ StatusOr XlaBuilder::InDimBroadcast( TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : broadcast_dimensions) { instr.add_dimensions(dim); } @@ -363,8 +380,9 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferUnaryOpShape(unop, operand_shape)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), unop, {operand}); }); } @@ -375,9 +393,10 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); + 
*instr.mutable_shape() = shape.ToProto(); const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); @@ -391,7 +410,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : instr.shape().dimensions()) { + for (int64 size : shape.dimensions()) { to_size.push_back(size); } for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); @@ -411,14 +430,14 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, } TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); - if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(instr.shape(), updated_lhs)); + AddBroadcastSequence(shape, updated_lhs)); } TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); - if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(instr.shape(), updated_rhs)); + AddBroadcastSequence(shape, updated_rhs)); } return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); @@ -432,30 +451,28 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferTernaryOpShape( - triop, lhs_shape, rhs_shape, ehs_shape)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape, + rhs_shape, ehs_shape)); + *instr.mutable_shape() = shape.ToProto(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp 
updated_ehs = ehs; - if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(shape)) { if (!ShapeUtil::IsTuple(lhs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + !ShapeUtil::SameDimensions(shape, lhs_shape)) { // lhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(instr.shape(), lhs)); + TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); } if (!ShapeUtil::IsTuple(rhs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + !ShapeUtil::SameDimensions(shape, rhs_shape)) { // rhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(instr.shape(), rhs)); + TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); } if (!ShapeUtil::IsTuple(ehs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + !ShapeUtil::SameDimensions(shape, ehs_shape)) { // ehs is being implicitly broadcasted. Change to explicit. 
- TF_ASSIGN_OR_RETURN(updated_ehs, - AddBroadcastSequence(instr.shape(), ehs)); + TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); } } return AddInstruction(std::move(instr), triop, @@ -476,7 +493,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = literal.shape(); + *instr.mutable_shape() = literal.shape().ToProto(); *instr.mutable_literal() = literal.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConstant); }); @@ -485,7 +502,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(iota_dimension); return AddInstruction(std::move(instr), HloOpcode::kIota); }); @@ -505,10 +522,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferCallShape(operand_shape_ptrs, - /*to_apply=*/called_program_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape( + operand_shape_ptrs, + /*to_apply=*/called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(computation, &instr); @@ -526,7 +543,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, } instr.set_parameter_number(parameter_number); instr.set_name(name); - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kParameter); }); } @@ -556,27 +573,35 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, } XlaOp 
XlaBuilder::BroadcastInDim( - const XlaOp& operand, const Shape& shape, + const XlaOp& operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(operand_shape, shape, - broadcast_dimensions) + // Output shape, in the case of degenerate broadcast, the out_dim_size is + // not necessarily the same as the dimension sizes of the output shape. + const auto& output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + + TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( + operand_shape, output_shape, broadcast_dimensions) .status()); - std::vector in_dim_size(ShapeUtil::Rank(shape)); - absl::c_copy(shape.dimensions(), in_dim_size.begin()); + std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); for (int i = 0; i < broadcast_dimensions.size(); i++) { in_dim_size[broadcast_dimensions[i]] = operand_shape.dimensions(i); } const auto& in_dim_shape = - ShapeUtil::MakeShape(shape.element_type(), in_dim_size); + ShapeUtil::MakeShape(operand_shape.element_type(), in_dim_size); TF_ASSIGN_OR_RETURN( XlaOp in_dim_broadcast, InDimBroadcast(in_dim_shape, operand, broadcast_dimensions)); - if (ShapeUtil::Equal(in_dim_shape, shape)) { + + // If broadcast is not degenerate, return broadcasted result. + if (ShapeUtil::Equal(in_dim_shape, output_shape)) { return in_dim_broadcast; } - return AddBroadcastSequence(shape, in_dim_broadcast); + + // Otherwise handle degenerate broadcast case. 
+ return AddBroadcastSequence(output_shape, in_dim_broadcast); }); } @@ -584,7 +609,7 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); } @@ -596,9 +621,9 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferSliceShape(operand_shape, start_indices, - limit_indices, strides)); + Shape shape, ShapeInference::InferSliceShape( + operand_shape, start_indices, limit_indices, strides)); + *instr.mutable_shape() = shape.ToProto(); for (int i = 0; i < start_indices.size(); i++) { auto* slice_config = instr.add_slice_dimensions(); slice_config->set_start(start_indices[i]); @@ -633,9 +658,10 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( operand_shape, start_indices_shape, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { instr.add_dynamic_slice_sizes(size); @@ -655,9 +681,10 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicUpdateSliceShape( operand_shape, update_shape, start_indices_shape)); + *instr.mutable_shape() = shape.ToProto(); return 
AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, {operand, update, start_indices}); @@ -673,9 +700,9 @@ XlaOp XlaBuilder::ConcatInDim(absl::Span operands, TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape( + operand_shape_ptrs, dimension)); + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); @@ -692,10 +719,9 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape, GetShape(padding_value)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferPadShape(operand_shape, padding_value_shape, - padding_config)); - + Shape shape, ShapeInference::InferPadShape( + operand_shape, padding_value_shape, padding_config)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_padding_config() = padding_config; return AddInstruction(std::move(instr), HloOpcode::kPad, @@ -708,7 +734,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(const Shape& shape, + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); XlaOp transposed = IsIdentityPermutation(dimensions) @@ -721,7 +747,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); 
std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -771,7 +797,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto(); *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); @@ -797,9 +823,10 @@ XlaOp XlaBuilder::Tuple(absl::Span elements) { TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); }); } @@ -814,7 +841,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = - ShapeUtil::GetTupleElementShape(tuple_shape, index); + ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto(); instr.set_tuple_index(index); @@ -873,9 +900,10 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_dot_dimension_numbers() = dimension_numbers; if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; @@ 
-1017,10 +1045,11 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, feature_group_count, instr.window(), dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); @@ -1093,10 +1122,9 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); - + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape( + operand_shape, fft_type, fft_length)); + *instr.mutable_shape() = shape.ToProto(); instr.set_fft_type(fft_type); for (int64 i : fft_length) { instr.add_fft_length(i); @@ -1114,7 +1142,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); if (ShapeUtil::IsArray(shape) && sharding() && @@ -1135,7 +1163,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { XlaOp token; auto make_token = [&]() { HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {}); }; if (sharding()) { @@ -1174,7 +1202,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { // TODO(b/80000000): Remove 
this when clients have been updated to handle // tokens. HloInstructionProto infeed_data; - *infeed_data.mutable_shape() = shape; + *infeed_data.mutable_shape() = shape.ToProto(); infeed_data.set_tuple_index(0); return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, {infeed}); @@ -1190,7 +1218,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); if (ShapeUtil::IsArray(shape) && sharding() && @@ -1215,7 +1243,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1228,14 +1256,14 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout; + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); instr.set_outfeed_config(outfeed_config); // Outfeed takes a token as its second operand. Generate the token to pass // to the outfeed. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -1249,7 +1277,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. 
HloInstructionProto tuple_instr; - *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto(); // The dummy tuple should have no sharding. { @@ -1268,7 +1296,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1281,7 +1309,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout; + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); instr.set_outfeed_config(outfeed_config); @@ -1293,7 +1321,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp XlaBuilder::CreateToken() { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(instr), HloOpcode::kAfterAll); }); } @@ -1303,8 +1331,17 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); } + for (int i = 0; i < tokens.size(); ++i) { + const XlaOp& operand = tokens[i]; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::IsToken(operand_shape)) { + return InvalidArgument( + "All operands to AfterAll must be tokens; operand %d has shape %s", + i, ShapeUtil::HumanString(operand_shape)); + } + } HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return 
AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens); }); } @@ -1321,7 +1358,7 @@ XlaOp XlaBuilder::CustomCall( "are reserved for internal use.", call_target_name); } - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); instr.set_custom_call_opaque(opaque); if (operand_shapes_with_layout.has_value()) { @@ -1345,7 +1382,7 @@ XlaOp XlaBuilder::CustomCall( "constrained layout.", operand_num); } - *instr.add_operand_shapes_with_layout() = operand_shape; + *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); ++operand_num; } } @@ -1499,9 +1536,9 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferTransposeShape(operand_shape, permutation)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape( + operand_shape, permutation)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : permutation) { instr.add_dimensions(dim); } @@ -1514,9 +1551,9 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReverseShape(operand_shape, dimensions)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape( + operand_shape, dimensions)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : dimensions) { instr.add_dimensions(dim); } @@ -1535,9 +1572,9 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, GetOperandShapes(values)); absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - 
ShapeInference::InferVariadicOpShape( - HloOpcode::kSort, operand_shape_ptrs)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); dimension = ShapeUtil::Rank(keys_shape) - 1; @@ -1559,9 +1596,9 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvertShape(operand_shape, new_element_type)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( + operand_shape, new_element_type)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); }); } @@ -1571,9 +1608,9 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvertShape(operand_shape, new_element_type)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( + operand_shape, new_element_type)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, {operand}); }); @@ -1605,11 +1642,11 @@ XlaOp XlaBuilder::Map(absl::Span operands, TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, - dimensions)); + Shape shape, ShapeInference::InferMapShape( + operand_shape_ptrs, called_program_shape, dimensions)); + *instr.mutable_shape() = shape.ToProto(); - const Shape& output_shape = 
instr.shape(); + Shape output_shape(instr.shape()); const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); @@ -1652,7 +1689,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.set_distribution(distribution); @@ -1680,10 +1717,10 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, condition.GetProgramShape()); TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferWhileShape(condition_program_shape, - body_program_shape, init_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape( + condition_program_shape, + body_program_shape, init_shape)); + *instr.mutable_shape() = shape.ToProto(); // Body comes before condition computation in the vector. 
AddCalledComputation(body, &instr); AddCalledComputation(condition, &instr); @@ -1700,10 +1737,10 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, start_indices_shape, + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape( + input_shape, start_indices_shape, dimension_numbers, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_gather_dimension_numbers() = dimension_numbers; for (int64 bound : slice_sizes) { @@ -1728,10 +1765,11 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, update_computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferScatterShape( input_shape, scatter_indices_shape, updates_shape, to_apply_shape, dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_scatter_dimension_numbers() = dimension_numbers; @@ -1758,10 +1796,11 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, false_computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferConditionalShape( predicate_shape, true_operand_shape, false_operand_shape, true_computation_shape, false_computation_shape)); + *instr.mutable_shape() = shape.ToProto(); // The index of true_computation must be 0 and that of false computation // must be 1. 
@@ -1803,9 +1842,10 @@ XlaOp XlaBuilder::Reduce(absl::Span operands, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferReduceShape( operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1868,10 +1908,10 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( MakeWindow(window_dimensions, window_strides, padding, /*lhs_dilation=*/base_dilations, /*rhs_dilation=*/window_dilations)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReduceWindowShape(operand_shape, init_shape, - instr.window(), to_apply_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape( + operand_shape, init_shape, + instr.window(), to_apply_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, @@ -1889,9 +1929,10 @@ XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferBatchNormTrainingShape( operand_shape, scale_shape, offset_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1913,10 +1954,11 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean)); TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferBatchNormInferenceShape( - operand_shape, scale_shape, offset_shape, - mean_shape, 
variance_shape, feature_index)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferBatchNormInferenceShape( + operand_shape, scale_shape, offset_shape, mean_shape, + variance_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1938,10 +1980,11 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean)); TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var)); TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBatchNormGradShape( operand_shape, scale_shape, batch_mean_shape, batch_var_shape, grad_output_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1972,9 +2015,9 @@ XlaOp XlaBuilder::CrossReplicaSum( return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( + {&operand_shape})); + *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; @@ -2027,8 +2070,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : 
replica_groups) { *instr.add_replica_groups() = group; } @@ -2053,8 +2096,9 @@ XlaOp XlaBuilder::CollectivePermute( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); HloInstructionProto instr; TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferCollectivePermuteShape(operand_shape)); + *instr.mutable_shape() = shape.ToProto(); for (const auto& pair : source_target_pairs) { auto* proto_pair = instr.add_source_target_pairs(); @@ -2103,10 +2147,11 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSelectAndScatterShape( operand_shape, select_shape, instr.window(), source_shape, init_shape, scatter_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(select, &instr); AddCalledComputation(scatter, &instr); @@ -2121,9 +2166,10 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReducePrecisionShape( operand_shape, exponent_bits, mantissa_bits)); + *instr.mutable_shape() = shape.ToProto(); instr.set_exponent_bits(exponent_bits); instr.set_mantissa_bits(mantissa_bits); return AddInstruction(std::move(instr), HloOpcode::kReducePrecision, @@ -2138,7 +2184,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. 
HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -2157,15 +2203,17 @@ XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, // token}. HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); send_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp send, AddInstruction(std::move(send_instr), HloOpcode::kSend, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -2179,7 +2227,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -2190,7 +2238,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. 
HloInstructionProto recv_data; - *recv_data.mutable_shape() = shape; + *recv_data.mutable_shape() = shape.ToProto(); recv_data.set_tuple_index(0); return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, {recv}); @@ -2207,15 +2255,18 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); @@ -2249,9 +2300,11 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, // Send instruction produces a tuple of {aliased operand, U32 context, // token}. 
HloInstructionProto send_instr; - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape_with_layout, ShapeUtil::MakeShape(U32, {}), - ShapeUtil::MakeTokenShape()}); + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape_with_layout, + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}) + .ToProto(); send_instr.set_channel_id(handle.handle()); send_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp send, @@ -2259,7 +2312,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); send_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, @@ -2288,8 +2341,10 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. 
HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_instr.set_channel_id(handle.handle()); recv_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), @@ -2297,7 +2352,8 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_done_instr.set_channel_id(handle.handle()); recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, @@ -2309,9 +2365,9 @@ XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape( + operand_shape, dimension)); + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, {operand}); @@ -2356,7 +2412,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator, GetNextId()); entry.set_root_id(root->id()); - ProgramShape* program_shape = entry.mutable_program_shape(); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending 
order (which is @@ -2617,9 +2673,10 @@ XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, +XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { - return operand.builder()->BroadcastInDim(operand, shape, + return operand.builder()->BroadcastInDim(operand, out_dim_size, broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 68314a026eab0db3eaf321f0fa53c016d79882ba..098efb60f9bdca8306ff771a505f4a225dea9f7d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -263,35 +264,30 @@ class XlaBuilder { // evaluating the computation. StatusOr IsConstant(const XlaOp& operand) const; + // Sets up binding which indicates that the `target_dim_num` in the subshape + // `target_param_index` of parameter `target_param_num` is a dynamic dimension + // and its real dynamic size is represented by `dynamic_size_param_index` in + // parameter `dynamic_size_param_num`. + // + // TODO(b/119520625): Remove this API once we have more dynamic shape infra + // ready. + Status SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, int64 target_dim_num); + private: // Build helper which takes the id of the root operation.. 
StatusOr Build(int64 root_id); - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. + // Description for the methods below can be found in the corresponding public + // functions section in this file. + XlaOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - // Enqueues a constant with the value of the given literal onto the - // computation. XlaOp ConstantLiteral(const LiteralSlice& literal); - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. template XlaOp ConstantR0(NativeT value); template @@ -321,181 +317,79 @@ class XlaBuilder { template XlaOp ConstantR4FromArray4D(const Array4D& values); - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. template XlaOp ConstantR1(int64 length, NativeT value); - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. 
- // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); - XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions); - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, absl::Span new_sizes); - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. 
- // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice XlaOp Slice(const XlaOp& operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. 
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. XlaOp ConcatInDim(absl::Span operands, int64 dimension); - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. void Trace(const string& tag, const XlaOp& operand); - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - // Enqueues a tuple-creation instruction onto the computation. XlaOp Tuple(absl::Span elements); - // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - // Enqueues an equal-to comparison instruction onto the computation. 
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); - // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). 
XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, @@ -503,8 +397,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, @@ -512,8 +404,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -521,8 +411,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -532,80 +420,53 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, const string& config = ""); - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. 
- // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const string& outfeed_config); - // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, absl::Span operands); - // Enqueues a custom call instruction onto the computation. XlaOp CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions = {}); - // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); - // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a remainder instruction onto the computation. 
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -624,32 +485,23 @@ class XlaBuilder { XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Reduces several arrays simultaneously among the provided dimensions, given - // "computation" as a reduction operator. XlaOp Reduce(absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding); - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, @@ -659,48 +511,22 @@ class XlaBuilder { absl::Span window_dilations, absl::Span> padding); - // Returns the sum of the operand value within each subgroup of replicas. 
All - // replicas supply one input to the sum and all replicas receive the resulting - // sum for each subgroup. XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups = {}); - // Enqueues an operation that do an AllReduce of the operand cross cores. Here - // AllReduce means doing a reduction on the input operand cross cores and then - // broadcasting the reduction result to those cores. The reduction function is - // defined by `computation`, which should be a commutative computation on - // scalars, e.g., add, min, or max. The way that AllReduce is applied is - // configured by: - // - // - `replica_groups`: each ReplicaGroup contains a list of replica id. If - // empty, all replicas belong to one group. Allreduce will be applied within - // subgroups. For example, we have 4 replicas, then - // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, - // replica 1 and 3 are in subgroup 1. - // - // - `channel_id`: for Allreduce nodes from different modules, if they have - // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will - // not be applied cross modules. - // - // TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); - // Enqueues an operation that do an Alltoall of the operand cross cores. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); - // Enqueues an operation that do an CollectivePermute of the operand cross - // cores. XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. 
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -708,8 +534,6 @@ class XlaBuilder { const XlaOp& init_value, const XlaComputation& scatter); - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -717,217 +541,119 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); - // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, absl::Span broadcast_dimensions = {}); - // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); - // Enqueues an expm1 instruction onto the computation. XlaOp Expm1(const XlaOp& operand); - // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); - // Enqueues a ceil instruction onto the computation. XlaOp Ceil(const XlaOp& operand); - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. XlaOp Round(const XlaOp& operand); - // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); - // Enqueues an log1p instruction (log(x+1)) onto the computation. XlaOp Log1p(const XlaOp& operand); - // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); - // Enqueues a count leading zeros instruction onto the computation. XlaOp Clz(const XlaOp& operand); - // Enqueues a cosine instruction onto the computation. XlaOp Cos(const XlaOp& operand); - // Enqueues a sine instruction onto the computation. XlaOp Sin(const XlaOp& operand); - // Enqueues a tanh instruction onto the computation. 
XlaOp Tanh(const XlaOp& operand); - // Enqueues a real-part instruction onto the computation. XlaOp Real(const XlaOp& operand); - // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. XlaOp IsFinite(const XlaOp& operand); - // Enqueues an iota operation onto the computation. XlaOp Iota(const Shape& shape, int64 iota_dimension); - // Enqueues a rank-1 iota operation onto the computation. XlaOp Iota(PrimitiveType type, int64 size); - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); - // Enqueues a transpose instruction onto the computation. XlaOp Transpose(const XlaOp& operand, absl::Span permutation); - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - // Enqueues a sort (as increasing order) instruction onto the computation. 
- // If only keys are provided: - // * If the keys are an rank-1 tensor (an array), the result is a sorted array - // of keys, in ascending order. - // * If the keys have higher rank, the keys are sorted along the provided - // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension - // value of 0 will indepenently sort every column, and a dimension value of 1 - // will independently sort each row. If no dimension number is provided, then - // the last dimension is chosen by default. - // - // If both keys and values are provided: - // * The keys and all values must be tensors with the same dimensions. The - // element types of the tensors may be different. - // * The result is a tuple that consists of a sorted tensor of keys (along the - // provided dimension, as above) as the first element, and tensors with their - // corresponding values as the other elements. XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); - // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - // Enqueues a map instruction onto the computation. XlaOp Map(absl::Span operands, const XlaComputation& computation, absl::Span dimensions, absl::Span static_operands = {}); - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); - // Enqueues a while node onto the computation. XlaOp While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init); - // Enqueues a conditional node onto the computation. 
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation); - // Enqueues a ReducePrecision node onto the computation. XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - // Enqueues a Gather node onto the computation. XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); - // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - // Enqueues a Send node onto the computation for device-to-device - // communication, to send the given operand to a Recv instruction that shares - // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle); - // Enqueues a Send node which sends data to the host. XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const ChannelHandle& handle); - // Enqueues a Recv node which receives data from the host. XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp CreateToken(); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp AfterAll(absl::Span tokens); - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. 
XlaOp Recv(const Shape& shape, const ChannelHandle& handle); XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index); - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. - // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, @@ -1019,6 +745,9 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // Dynamic parameter configuration of this computation. 
+ DynamicParameterBinding dynamic_parameter_binding_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; @@ -1096,7 +825,7 @@ class XlaBuilder { absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, + const XlaOp& operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, @@ -1393,6 +1122,7 @@ class XlaScopedShardingAssignment { // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. +// // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. @@ -1488,7 +1218,8 @@ XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // will generate output // {{1 , 1}, // {2 , 2}} -XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, +XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on @@ -2138,6 +1869,7 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); // Implementation details below this point. 
+// template XlaOp XlaBuilder::ConstantR0(NativeT value) { diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 8aa85c3cd63c9b0aeb55d2cebbb989b6432ac959..b3f5be300d3f15397ad33858a6a9cab5f6029688 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -267,7 +267,7 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { TEST_F(XlaBuilderTest, BroadcastInDim) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); - BroadcastInDim(x, ShapeUtil::MakeShape(F32, {2, 4, 3}), + BroadcastInDim(x, {2, 4, 3}, /*broadcast_dimensions=*/{0, 2}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); auto root = module->entry_computation()->root_instruction(); @@ -277,7 +277,7 @@ TEST_F(XlaBuilderTest, BroadcastInDim) { TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); - BroadcastInDim(x, ShapeUtil::MakeShape(F32, {2, 3, 4}), + BroadcastInDim(x, {2, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2}); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -446,5 +446,14 @@ TEST_F(XlaBuilderTest, ProtoMatches) { EXPECT_EQ(c0_string, c1_string); } +TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { + XlaBuilder b(TestName()); + AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); + Status status = b.Build().status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("All operands to AfterAll must be tokens")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index c9870b65b91c1ebd7d44143faf215a2d5c2a2fc5..f317892c12529b2ee8a81788f6bbcae3b3d6489d 100644 --- 
a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -25,7 +25,7 @@ namespace xla { StatusOr XlaComputation::GetProgramShape() const { TF_RET_CHECK(proto_.has_host_program_shape()); - return proto_.host_program_shape(); + return ProgramShape(proto_.host_program_shape()); } StatusOr> XlaComputation::Snapshot() const { diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index 71598ef8b296a760b0ee818fce0a59aed5cfc6b4..3ccbfb28bd0c5939ee40878e9cc298688882ac62 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 033887d7c11bb530d70f0653f26c61bcbfe1e321..20609cad58d920c0c272899c41efeb99d23cd490 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -54,7 +54,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. 
flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_min_max(true); flags->set_xla_force_host_platform_device_count(1); } @@ -160,11 +160,11 @@ void AllocateFlags() { "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( - "xla_gpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the GPU compiler; " - "this may produce faster code at the expense of some accuracy."), + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", bool_setter_for( @@ -334,8 +334,14 @@ void AllocateFlags() { "overhead from context switching but we let the user override this " "behavior to help run tests on the host that run models in parallel " "across multiple devices."), + tensorflow::Flag( + "xla_gpu_disable_ptxas_optimizations", + bool_setter_for( + &DebugOptions::set_xla_gpu_disable_ptxas_optimizations), + flag_values->xla_gpu_disable_ptxas_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3)."), }); - ParseFlagsFromEnv(*flag_objects); + ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } } // namespace diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index fb135f5ceda67ce6c001de15b8f3f084ca164826..1fea816a803bfb75b9721393cef8c4dfc249268d 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ 
import division from __future__ import print_function -import math - import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.core.framework import attr_value_pb2 @@ -64,22 +61,18 @@ class Sharding(object): tile_assignment_devices=[core])) @classmethod - def tile(cls, tile_shape, tile_assignment): + def tile(cls, tile_assignment): """Returns a Tiled sharding attribute. This causes an op to be partially computed on multiple cores in the XLA device. Args: - tile_shape: A xla_shape.Shape describing the tile shape that each core - will compute. - The tile shape does not need to be divisible by the tile assignment. tile_assignment: An np.ndarray describing the topology of the tiling and which device will compute which part of the topology. Raises: - TypeError: tile_assignment was not of np.array type or tile_shape was - not of xla_shape.Shape type. + TypeError: tile_assignment was not of np.array type. 
TODO(jmolloy): This concept is nefarious and is not something we really want to expose to users (especially as the @@ -87,14 +80,11 @@ class Sharding(object): """ if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') - if not isinstance(tile_shape, xla_shape.Shape): - raise TypeError('Tile shape must be of type xla_shape.Shape') dims = list(tile_assignment.shape) flattened_devices = tile_assignment.reshape(-1, order='C') return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape.message, tile_assignment_dimensions=dims, tile_assignment_devices=list(flattened_devices))) @@ -118,14 +108,8 @@ class Sharding(object): shape = tensor.shape.as_list() if shape[split_dimension] < num_devices: raise ValueError('Split dimension was smaller than the required number ' - 'of splits: shape=%r, dimension=%r, num_devices=%r', - shape, split_dimension, num_devices) - - tile_shape = shape - tile_shape[split_dimension] = int( - math.ceil(tile_shape[split_dimension] / num_devices)) - tile_shape_proto = xla_data_pb2.Shape( - element_type=xla_data_pb2.F32, dimensions=tile_shape) + 'of splits: shape=%r, dimension=%r, num_devices=%r' % + (shape, split_dimension, num_devices)) tile_assignment_dims = [1] * len(shape) tile_assignment_dims[split_dimension] = num_devices @@ -133,7 +117,6 @@ class Sharding(object): return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape_proto, tile_assignment_dimensions=tile_assignment_dims, tile_assignment_devices=range(num_devices))) @@ -149,7 +132,6 @@ class Sharding(object): type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) else: proto = self._proto - attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) # TODO(jmolloy): This need to be seriously revisited before declaring this # API available for public use. 
@@ -194,8 +176,8 @@ def assign_device(tensor, device): return tensor -def tile(tensor, tile_shape, tile_assignment): - Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) +def tile(tensor, tile_assignment): + Sharding.tile(tile_assignment).apply_to_tensor(tensor) return tensor diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index bcfbcc3a22f50c748c388d17fbcd7defd27846d0..267701e9c0e42a21d2cda6238520f6a9692e7e76 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -3,15 +3,15 @@ upper_tabs: - include: /_upper_tabs_left.yaml - include: /api_docs/_upper_tabs_api.yaml # Dropdown menu -- name: Ecosystem - path: /ecosystem +- name: Resources + path: /resources is_default: true menu: - - include: /ecosystem/_menu_toc.yaml + - include: /resources/_menu_toc.yaml lower_tabs: # Subsite tabs other: - - name: Guide + - name: Guide & Tutorials contents: - title: XLA overview path: /xla/overview @@ -27,3 +27,9 @@ upper_tabs: path: /xla/shapes - title: Using AOT compilation path: /xla/tfcompile + - heading: Tutorials + - title: XLA compile API + path: /xla/tutorials/xla_compile + status: experimental + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/xla/g3doc/_index.yaml b/tensorflow/compiler/xla/g3doc/_index.yaml index 7934cd11ba22d3f47e172726f54ce51d15eb2cad..858de427119bfcfa82d0b1158776bf269129fd92 100644 --- a/tensorflow/compiler/xla/g3doc/_index.yaml +++ b/tensorflow/compiler/xla/g3doc/_index.yaml @@ -17,7 +17,7 @@ landing_page: - classname: devsite-landing-row-cards items: - heading: XLA - TensorFlow, compiled - image_path: /ecosystem/images/tf-logo-card-16x9.png + image_path: /resources/images/tf-logo-card-16x9.png path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html buttons: - label: Read on Google Developers blog @@ -28,7 +28,7 @@ landing_page: - label: Watch the video path: 
https://www.youtube.com/watch?v=kAOanJczHA0 - heading: XLA on GitHub - image_path: /ecosystem/images/github-card-16x9.png + image_path: /resources/images/github-card-16x9.png path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla buttons: - label: View on GitHub diff --git a/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png new file mode 100644 index 0000000000000000000000000000000000000000..00cefe4c7806c1c09dd51499375e720bfb0baac6 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png new file mode 100644 index 0000000000000000000000000000000000000000..6439c6e40272ae6b2954e9d7f3de2df470a2b36d Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png differ diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index ded1e582b24c7a45acc6b61ba9c018fa2a1e7db7..85fa16ccc7f48a3dce840564e79097c9e136767f 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -86,7 +86,7 @@ on uncompilable operator, xla.compile() returns an explicit error. This is useful if you want more predictable behaviors from XLA compilation. Please see -[xla.compile() tutorial Colab](https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb) +[xla.compile() tutorial Colab](./tutorials/xla_compile.ipynb) for how to use it. ### Placing operators on XLA devices @@ -144,7 +144,7 @@ Execute the python script to train the model with XLA and turn on a debugging feature of XLA via an environmental variable that outputs the XLA graph. 
```shell -TF_XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py +XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py ``` Open the timeline file created (`timeline.ctf.json`). The rendered timeline diff --git a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md b/tensorflow/compiler/xla/g3doc/layout_with_tiling.md new file mode 100644 index 0000000000000000000000000000000000000000..5e990851af7495ebd4417e44f1d955fcc14dadf1 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/layout_with_tiling.md @@ -0,0 +1,159 @@ +# Tiled layout + +*Note: This doc describes how tiled layout is intended to work. Tiling is being +implemented, but this is an early effort and it is currently not even guaranteed +to get an Unimplemented error if one tries to use tiling - it may be just +silently ignored.* + +
![](images/xla_array_layout_figure1.png) + +Figure 1
+ +Figure 1 shows how an array F32[3,5] is laid out in memory with 2x2 tiling. A +shape with this layout is written as F32[3,5]{1,0:(2,2)}, where 1,0 relates to +the physical order of dimensions (minor_to_major field in Layout) while (2,2) +after the colon indicates tiling of the physical dimensions by a 2x2 tile. + +Intuitively tiles are laid out to cover the shape and then within each tile, +elements are then laid out without tiling, as in the example above, where the +right part of the example shows the layout in memory, including the white +padding elements that are added in order to have complete 2x2 tiles even though +the original array bounds are not even. + +The extra elements in the padding are not required to contain any particular +value. + +## Linear index formulas for tiling given a shape and a tile + +Without tiling, an element $e=(e_n, e_{n-1}, \ldots, e_1)$ +in an array with array bounds $d=(d_n, d_{n-1}, \ldots, d_1)$ +($d_1$ is the most minor dimension) is laid out by major to +minor order at position: + +   $\mathrm{linear\_index}(e, d)$ \ +$= \mathrm{linear\_index}((e_n, e_{n-1}, \ldots, e_1),$ +$(d_n, d_{n-1}, \ldots, d_1))$ \ +$= e_n d_{n-1} \cdots d_1 +$ +$e_{n-1} d_{n-2} \cdots d_1 + \cdots + e_1$ + +For simplicity of notation in this document we assume a tile has the same number +of dimensions as the array. In XLA's implementation of tiling, this is +generalized to tilings with fewer dimensions by leaving the initial most-major +dimensions unchanged and applying the tiling only to the most minor dimensions, +so that the tiling that is specified mentions a suffix of the physical +dimensions of the shape being tiled. + +When tiling of size $(t_n, t_{n-1}, \ldots, t_1)$ is +used, an element in the array with indices +$(e_n, e_{n-1}, \ldots, e_1)$ is mapped to this position in the final layout: + +   $\mathrm{linear\_index\_with\_tile}(e, d, t)$ \ +$= \mathrm{linear\_index}((\lfloor e/t \rfloor, e \bmod t), (\lceil d/t \rceil, t))$     (arithmetic is +elementwise, $(a,b)$ is concatenation) \ +$= \mathrm{linear\_index}((\lfloor e_n/t_n \rfloor, \ldots,$ +$\lfloor e_1/t_1 \rfloor, e_n \bmod t_n, \ldots,$ +$e_1 \bmod t_1), (\lceil d_n/t_n \rceil, \ldots,$ +$\lceil d_1/t_1 \rceil, t_n, t_{n-1}, \ldots,$
+$t_1))$ \ +$= \mathrm{linear\_index}((\lfloor e_n/t_n \rfloor, \ldots,$ +$\lfloor e_1/t_1 \rfloor), (\lceil d_n/t_n \rceil, \ldots,$ +$\lceil d_1/t_1 \rceil)) \cdot t_n t_{n-1} \cdots t_1 +$ +$\mathrm{linear\_index}((e_n \bmod t_n, \ldots, e_1 \bmod t_1),$ +$(t_n, t_{n-1}, \ldots, t_1))$ + +The layout can be thought of as having two parts: +$(\lfloor e_n/t_n \rfloor, \ldots, \lfloor e_1/t_1 \rfloor)$, which +corresponds to a tile index in an array of tiles of size +$(\lceil d_n/t_n \rceil, \ldots, \lceil d_1/t_1 \rceil)$, and +$(e_n \bmod t_n, \ldots, e_1 \bmod t_1)$, which +corresponds to a within-tile index. The ceil function appears in +$\lceil d_i/t_i \rceil$ because if tiles overrun the bounds of the larger +array, padding is inserted as in Figure 1. Both the tiles and elements within +tiles are laid out recursively without tiling. + +For the example in Figure 1, element (2,3) has tile index (1,1), and within-tile +index (0,1), for a combined coordinate vector of (1, 1, 0, 1). The tile indices +have bounds (2, 3) and the tile itself is (2, 2) for a combined vector of (2, 3, +2, 2). The linear index with tile for the element with index (2, 3) in the +logical shape is then + +   linear_index_with_tile((2,3), (3,5), (2,2)) \ += linear_index((1,1,0,1), (2,3,2,2)) \ += linear_index((1,1), (2,3)) ∙ 2 ∙ 2 + linear_index((0,1), (2,2)) \ += (1 ∙ 3 + 1) ∙ 2 ∙ 2 + (0 ∙ 2 + 1) \ += 17. + +# Tiling as pad-reshape-transpose + +Tiling-based layout operates as follows: \ +Consider an array of dimensions $(d_n, d_{n-1}, \ldots, d_1)$ ($d_1$ +is the most minor dimension). When it’s laid out with tiling of size +$(t_n, t_{n-1}, \ldots, t_1)$ ($t_1$ is the most +minor dimension), that tiling can be described in terms of pad-reshape-transpose +in the following way. + +1. The array is padded to $(\lceil d_n/t_n \rceil \cdot t_n, \ldots,$ + $\lceil d_1/t_1 \rceil \cdot t_1)$. +2. Each dimension $i$ is broken into $(\lceil d_i/t_i \rceil,$ + $t_i)$, i.e. the array is reshaped to \ +     $(\lceil d_n/t_n \rceil, t_n, \ldots,$ + $\lceil d_1/t_1 \rceil, t_1)$. \ + There is no physical layout change in this reshape by itself, so this + reshape is a bitcast. 
If one is not explicitly thinking of a tiling, this + reshape could express any shape with the same number of elements as the + padded shape - the example here is of how to express a tile in this way. +3. A transpose happens by moving $t_n, \ldots, t_1$ to the most + minor dimensions while keeping their relative order, so that the order of + dimensions from most major to most minor becomes \ +     $(\lceil d_n/t_n \rceil, \ldots,$ + $\lceil d_1/t_1 \rceil, t_n, \ldots, t_1)$. + +The final shape has the prefix \ +    $(\lceil d_n/t_n \rceil, \ldots,$ +$\lceil d_1/t_1 \rceil)$, which describes the number of tiles in each +dimension. An element in the array $(e_n, \ldots, e_1)$ is +mapped to this element in the final shape: \ +    $(\lfloor e_n/t_n \rfloor, \ldots,$ +$\lfloor e_1/t_1 \rfloor, e_n \bmod t_n, \ldots,$ +$e_1 \bmod t_1)$. It is easy to see that the linear index of the +element follows the formula above as expected. + +# Repeated tiling + +XLA's tiling becomes even more flexible by applying it repeatedly. + +
![](images/xla_array_layout_figure2.png) + +Figure 2
+ +Figure 2 shows how an array of size 4x8 is tiled by two levels of tiling (first +2x4 then 2x1). We represent this repeated tiling as (2,4)(2,1). Each color +indicates a 2x4 tile and each red border box is a 2x1 tile. The numbers +indicate the linear index in memory of each element in the tiled format. This +format matches the format used for BF16 on TPU, except that the initial tile is +bigger, namely the tiling is (8,128)(2,1), where the purpose of the second +tiling by 2x1 is to collect together two 16 bit values to form one 32 bit value +in a way that aligns with the architecture of a TPU. + +Note that a second or later tile can refer to the minor within-tile +dimensions, which just rearranges data within the tile, as in this example with +(8,128)(2,1), but can also refer to the major cross-tile dimensions from the +prior tiling. + +# Combining dimensions using tiles + +XLA's tiling also supports combining dimensions. For example, it can combine +dimensions in F32[2,7,8,11,10]{4,3,2,1,0} into F32[112,110]{1,0} first before +tiling it with (2,3). The tile used is (∗,∗,2,∗,3). Here an +asterisk in a tile implies taking that dimension and combining it with the next +more minor dimension. Multiple adjacent dimensions can be subsumed together into +one dimension. A subsumed dimension is represented by a tile value of -1 in that +dimension of the tile, which is not otherwise valid in a tile as a dimension +size. + +More precisely, if dimension i of the shape is eliminated via an asterisk in the +tile, then before the prior definition of tiling is applied, that dimension is +removed from both the shape being tiled and the tile vector, and what was +dimension i-1 of the shape has its array bound increased from $d_{i-1}$ to +$d_i d_{i-1}$. This step is repeated for each asterisk in the +tile vector. 
diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 73a9db75f6bf090bba5c3534f14d8ebfa421b5bb..d888b1f23f36f33ef94ef0e22374e0c796e47a89 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -13,6 +13,22 @@ arbitrary-dimensional array. For convenience, special cases have more specific and familiar names; for example a *vector* is a 1-dimensional array and a *matrix* is a 2-dimensional array. +## AfterAll + +See also +[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +AfterAll takes a variadic number of tokens and produces a single token. Tokens +are primitive types which can be threaded between side-effecting operations to +enforce ordering. `AfterAll` can be used as a join of tokens for ordering an +operation after a set of operations. + + `AfterAll(operands)` + +Arguments | Type | Semantics +---------- | ------- | ------------------------- +`operands` | `XlaOp` | variadic number of tokens + ## AllToAll See also @@ -402,6 +418,33 @@ then v12 == f32[8x3] {{10, 11, 12}, ``` +## CollectivePermute + +See also +[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +CollectivePermute is a collective operation that sends and receives data across +replicas. + + `CollectivePermute(operand, source_target_pairs)` + +| Arguments | Type | Semantics | +| --------------------- | ----------------------- | -------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `source_target_pairs` | vector of `(int64, int64)` pairs | A list of | +: : : (source_replica_id, : +: : : target_replica_id) pairs. : +: : : For each pair, the operand : +: : : is sent from source : +: : : replica to target replica. 
: + +Note that there are the following restrictions on the `source_target_pairs`: + +- Any two pairs should not have the same target replica id, and they should + not have the same source replica id. +- If a replica id is not a target in any pair, then the output on that replica + is a tensor consisting of 0(s) with the same shape as the input. + ## Concatenate See also @@ -1423,10 +1466,11 @@ Builds a constant literal on device rather than a potentially large host transfer. Creates a rank 1 array of values starting at zero and incrementing by one. -Arguments | Type | Semantics ---------- | --------------- | ------------------------------------ -`type` | `PrimitiveType` | type U -`size` | `int64` | The number of elements in the array. +Arguments | Type | Semantics +---------------- | --------------- | ------------------------------------ +`type` | `PrimitiveType` | type U +`size` | `int64` | The number of elements in the array. +`iota_dimension` | `int64` | The dimension to increment along.
## Map @@ -1780,8 +1824,9 @@ XlaBuilder builder(client_, "reduce_window_2x3"); auto shape = ShapeUtil::MakeShape(F32, {4, 6}); auto input = builder.Parameter(0, shape, "input"); builder.ReduceWindow( - input, *max, + input, /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)), + *max, /*window_dimensions=*/{2, 3}, /*window_stride_dimensions=*/{2, 3}, Padding::kValid); diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb index a83e3f78598e7c0afaada43b8ae1ba71ad4839d6..2a83092805be5efdd7b9ab54449b2bcc6a2ec481 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -1,25 +1,38 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "The XLA compile API", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "f4TSNCvpENrW" }, + "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { "cellView": "form", - "colab": {}, "colab_type": "code", - "id": "vamNSA0vEP-m" + "id": "vamNSA0vEP-m", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -32,139 +45,84 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
- ] - }, - { - "cell_type": "code", + ], "execution_count": 0, - "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", - "id": "xD_ydfejEV7H" - }, - "outputs": [], - "source": [ - "#@title MIT License\n", - "#\n", - "# Copyright (c) 2017 François Chollet\n", - "#\n", - "# Permission is hereby granted, free of charge, to any person obtaining a\n", - "# copy of this software and associated documentation files (the \"Software\"),\n", - "# to deal in the Software without restriction, including without limitation\n", - "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", - "# and/or sell copies of the Software, and to permit persons to whom the\n", - "# Software is furnished to do so, subject to the following conditions:\n", - "#\n", - "# The above copyright notice and this permission notice shall be included in\n", - "# all copies or substantial portions of the Software.\n", - "#\n", - "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", - "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", - "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", - "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", - "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", - "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", - "# DEALINGS IN THE SOFTWARE." 
- ] + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "e1oSi4lHFt3z" }, + "cell_type": "markdown", "source": [ - "# Welcome to `xla.compile()` tutorial" + "# The XLA compile API" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "b7noD9NjFRL-" }, + "cell_type": "markdown", "source": [ - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/jit#turning_on_jit_compilation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e" + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "v9YbsuLZaBXy" }, + "cell_type": "markdown", "source": [ - "xla.compile() is a new experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/).\n", "\n", - "Please run all code blocks in order." + "\n", + "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "45kUPj5ZFrRa" - }, - "outputs": [], - "source": [ - "import tensorflow as tf" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "9NMQFjroSMns" + "id": "45kUPj5ZFrRa", + "colab": {} }, - "source": [ - "Imports XLA library, which includes xla.compile() experimental API." - ] - }, - { "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "-Uggy03rSGJm" - }, - "outputs": [], "source": [ + "import tensorflow as tf\n", + "\n", "from tensorflow.contrib.compiler import xla" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GZVNiRmTDV-5" }, + "cell_type": "markdown", "source": [ - "Define some necessary constants and prepare MNIST dataset." + "Define some necessary constants and prepare the MNIST dataset." 
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "f37TSEGvGX4_" + "id": "f37TSEGvGX4_", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Size of each input image, 28 x 28 pixels\n", "IMAGE_SIZE = 28 * 28\n", @@ -174,17 +132,17 @@ "TRAIN_BATCH_SIZE = 100\n", "# Number of training steps to run\n", "TRAIN_STEPS = 1000" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "TiVXchblG5hK" + "id": "TiVXchblG5hK", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Loads MNIST dataset.\n", "train, test = tf.keras.datasets.mnist.load_data()\n", @@ -195,16 +153,18 @@ "images, labels = iterator.get_next()\n", "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "x_ZehpZP-SfS" }, + "cell_type": "markdown", "source": [ - "## Defines build_mnist_model function to construct model\n", + "# Define the model constructing function\n", "\n", "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", "\n", @@ -212,14 +172,12 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ZbhJl_WvGa3g" + "id": "ZbhJl_WvGa3g", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def build_mnist_model(x, y_):\n", " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", @@ -228,47 +186,41 @@ " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", "\n", " return y, train_step" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": 
"7Jh3lyQHDfM9" }, - "source": [ - "## Uses xla.compile with build_mnist_model function to enable XLA" - ] - }, - { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "EtDwez_1gjzv" - }, "source": [ - "Following code block wraps the model with xla.compile(), which allows the target function with provided inputs to be executed by XLA." + "# Enable XLA\n", + "\n", + "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "kYpCXCdRHNuN" + "id": "kYpCXCdRHNuN", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4giQh62IrZGF" }, + "cell_type": "markdown", "source": [ "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", "\n", @@ -293,62 +245,62 @@ ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "TPGas4jjFLZl" }, + "cell_type": "markdown", "source": [ "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready." 
] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EZD1m_n1DxAF" }, + "cell_type": "markdown", "source": [ - "## Trains and tests the model" + "# Train and test the model" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "qe28bAHNHUG2" + "id": "qe28bAHNHUG2", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Creates session and initialize all variables.\n", "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", "sess = tf.Session()\n", "sess.run(tf.global_variables_initializer())" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qgsKmz3n2UiW" }, + "cell_type": "markdown", "source": [ - "Following code block trains model.\n", - "\n", - "Note that evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." + "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." 
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "_GxF6jTRHVuA" + "id": "_GxF6jTRHVuA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132" }, - "outputs": [], + "cell_type": "code", "source": [ "# Feeds training dataset\n", "sess.run(iterator.make_initializer(train_ds))\n", @@ -356,18 +308,31 @@ "# Runs TRAIN_STEPS steps\n", "for i in range(TRAIN_STEPS):\n", " sess.run(y)\n", + "\n", "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" + ], + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Model trained for 1000 steps.\n" + ], + "name": "stdout" + } ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "dHlQlRSRHXD1" + "id": "dHlQlRSRHXD1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093" }, - "outputs": [], + "cell_type": "code", "source": [ "# Tests trained model\n", "\n", @@ -378,35 +343,31 @@ "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Prediction accuracy after training: 0.91\n" + ], + "name": "stdout" + } ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "ynJQIuzjHYOb" + "id": "ynJQIuzjHYOb", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# Cleans up session\n", "sess.close()" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "xla.compile() Tutorial", - "provenance": [], - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 2", - "name": "python2" + ], + "execution_count": 
0, + "outputs": [] } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 458bdaf2f89819d2fbd8518150d11b42ce9f9c6e..d76f61eb62c0fc89d6bc3ca2033e8c7170f30e78 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 2398470dd49955f154dcb32edae6f3b9f961f89d..dbb81381acde645f08639737b6e7b6f6ad971f9b 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -460,6 +460,13 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { } hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); + for (Tile tile : layout.tiles()) { + for (int64 tile_dim : tile.dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(tile_dim)); + } + } + hash_value = Hash64Combine(hash_value, layout.element_size_in_bits()); + return hash_value; } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6e0390763da15167b85597462f3e21b8e1eaf732..6c298e57252449ce3f1f9055436e918f2d9f17f1 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index cb00a0ab16df851ccbd4bba960b92ea83157867d..8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -62,6 +63,14 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be +// able to transparently access the raw 16-bit value contained within. 
+template +T GetRawValue(T val) { + return val; +} +uint16 GetRawValue(Eigen::half val) { return val.x; } + } // namespace LiteralBase::~LiteralBase() {} @@ -283,16 +292,17 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } - if (ShapeUtil::HasPrimitiveType(proto.shape(), OPAQUE)) { + Shape shape(proto.shape()); + if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) { return InvalidArgument("Literal shape cannot include OPAQUE sub-shape"); } - if (!LayoutUtil::HasLayout(proto.shape())) { + if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("LiteralProto has no layout"); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - Literal literal(proto.shape()); + Literal literal(shape); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -1012,166 +1022,143 @@ void LiteralBase::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { - const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - CHECK(LayoutUtil::HasLayout(literal.shape())); - CHECK(LayoutUtil::HasLayout(subshape)); +string ShapeToString(bool print_layout, const Shape& shape) { + return print_layout ? ShapeUtil::HumanStringWithLayout(shape) + : ShapeUtil::HumanString(shape); +} - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces); - // TODO(b/32894291): refactor this code to reduce code duplication. 
- if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); +void TupleToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); + } + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); + pieces->push_back("\n)"); +} + +void SparseArrayToStringHelper(const LiteralBase& literal, + const Shape& subshape, bool print_layout, + std::vector* pieces) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); } - pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (ShapeUtil::IsToken(subshape)) { - pieces->push_back("token"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i 
< num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); } - pieces->push_back("}"); - return; + pieces->push_back(literal.GetSparseElementAsString(i)); } + pieces->push_back("}"); +} - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = [&](absl::Span indices) -> string { - PrimitiveType element_type = subshape.element_type(); - // We display predicates as 0s and 1s so that the string is more dense. - string elem = element_type == PRED - ? literal.Get(indices, shape_index) ? "1" : "0" - : literal.GetAsString(indices, shape_index); - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + elem; - }; +void DenseArrayToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + int64 rank = ShapeUtil::Rank(subshape); + + std::function dimensions, std::vector*)> + to_string_recursive = [&](absl::Span dimensions, + std::vector* accum_indices) { + // dimensions.size() decreases by 1 at each recursive call, + // and accum_indices->size() increases by 1. + // Their sum is equal to the rank of the tensor. + CHECK_EQ(rank, dimensions.size() + accum_indices->size()); + + auto brace_to_string = [&](string brace) -> string { + // Handle 1D tensor + if (rank == 1) { + return brace; + } + // Handle the innermost tensor of a 2D+ tensor. 
+ if (dimensions.size() == 1 && brace == "{") { + return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " "); + } + if (dimensions.size() == 1 && brace == "}") { + return StrCat(dimensions[0] <= 1 ? "" : " ", brace); + } + // Handle the non-innermost tensors of a 2D+ tensor. + if (brace == "{") { + if (rank > 3 && !accum_indices->empty() && + accum_indices->size() < rank) { + int index = accum_indices->size() - 1; + int value = accum_indices->back(); + return StrCat(brace, " /*i", index, "=", value, "*/\n"); + } + return StrCat(brace, "\n"); + } + return StrCat("\n", brace); + }; - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? 
",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); + if (dimensions.empty()) { + // Display predicates as 0s and 1s so that the string is more dense. + string elem; + if (subshape.element_type() == PRED && rank > 0) { + elem = literal.Get(*accum_indices, shape_index) ? "1" : "0"; + } else { + elem = literal.GetAsString(*accum_indices, shape_index); } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + pieces->push_back(elem); + } else { + pieces->push_back(brace_to_string("{")); + for (int i = 0; i < dimensions[0]; ++i) { + std::vector cloned_indices(*accum_indices); + cloned_indices.push_back(i); + to_string_recursive(dimensions.subspan(1), &cloned_indices); + if (i < dimensions[0] - 1) { + pieces->push_back(","); + pieces->push_back(dimensions.size() > 1 ? "\n" : " "); } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); + pieces->push_back(brace_to_string("}")); } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); + }; + + if (rank > 1) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" "); + } + std::vector indices = {}; + std::vector dimensions(subshape.dimensions().begin(), + subshape.dimensions().end()); + to_string_recursive(dimensions, &indices); +} + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + if (ShapeUtil::IsTuple(subshape)) { + TupleToStringHelper(literal, shape_index, print_layout, pieces); + } else if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + } else if (LayoutUtil::IsSparseArray(subshape)) { + SparseArrayToStringHelper(literal, subshape, print_layout, pieces); } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](absl::Span indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); + CHECK(LayoutUtil::IsDenseArray(subshape)); + DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); } } @@ -1228,16 +1215,32 @@ Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { } template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) && + !std::is_same::value), Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { - return absl::bit_cast(src); + return absl::bit_cast(GetRawValue(src)); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); } +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) && + std::is_same::value), + Literal>::type +BitcastBetweenNativeTypes(const LiteralBase& 
src_literal) { + // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly + // cast to unsigned short and then use raw_uint16_to_half. + auto converter = [](NativeSrcT src) { + return Eigen::half_impl::raw_uint16_to_half( + absl::bit_cast(GetRawValue(src))); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + // This template specialization is here to make the compiler happy. bit_cast has // a static check that the types are the same size. This specialization should // never be used because the source and destination types are checked for @@ -1792,7 +1795,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { - *proto->mutable_shape() = subshape(); + *proto->mutable_shape() = subshape().ToProto(); switch (subshape().element_type()) { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); @@ -1898,8 +1901,9 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in // MutableLiteralBase::CreateFromProto. 
TF_RET_CHECK(proto.has_shape()); - TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); - TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + Shape shape(proto.shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(shape)); + TF_RET_CHECK(ShapeUtil::Equal(shape, subshape())); if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index e791048b4d9f5dcf877e05e3b5cf16eb37c07dbc..fa9a71af4ceb998a7a289443cbef70eb52cb1a11 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -301,7 +301,7 @@ class LiteralBase { // // Note: It's an antipattern to use this method then immediately call // MutableLiteralBase::Populate on the result (since that results in zero - // initialization, then reinitialization. Conside if a call to + // initialization, then reinitialization. Consider if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. 
static Literal CreateFromShape(const Shape& shape); diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 8cec37897a94472d61d2346cf4cab03c45033800..49363ad802ddb9520f89b53257216bc7ddaf8ff5 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -150,12 +150,58 @@ TEST_F(LiteralUtilTest, R3ToString) { const auto literal = LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { -{ { 1 }, - { 2 } }, -{ { 3 }, - { 4 } }, -{ { 5 }, - { 6 } } +{ + {1}, + {2} +}, +{ + {3}, + {4} +}, +{ + {5}, + {6} +} +})"; + EXPECT_EQ(expected, literal.ToString()); +} + +TEST_F(LiteralUtilTest, R6ToString) { + const auto literal = + LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2}); + const string expected = R"(s32[2,2,1,1,1,2] { +{ /*i0=0*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +}, +{ /*i0=1*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +} })"; EXPECT_EQ(expected, literal.ToString()); } @@ -190,12 +236,16 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[2,3,2] { -{ { 1, 2 }, +{ + { 1, 2 }, { 3, 4 }, - { 5, 6 } }, -{ { 7, 8 }, + { 5, 6 } +}, +{ + { 7, 8 }, { 9, 10 }, - { 11, 12 } } + { 11, 12 } +} })"; EXPECT_EQ(expected, result); } @@ -247,18 +297,18 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - }, - { /*i1=1*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +}, +{ 
/*i1=1*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +} +} })"; EXPECT_EQ(expected, result); } @@ -268,30 +318,30 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { ElementsAre(2, 2, 3, 3)); string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2, 3}, - {4, 5, 6}, - {7, 8, 9} - }, - { /*i1=1*/ - {11, 12, 13}, - {14, 15, 16}, - {17, 18, 19} - } - }, - { /*i0=1*/ - { /*i1=0*/ - {101, 102, 103}, - {104, 105, 106}, - {107, 108, 109} - }, - { /*i1=1*/ - {201, 202, 203}, - {204, 205, 206}, - {207, 208, 209} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } +}, +{ /*i1=1*/ + { 11, 12, 13 }, + { 14, 15, 16 }, + { 17, 18, 19 } +} +}, +{ /*i0=1*/ +{ /*i1=0*/ + { 101, 102, 103 }, + { 104, 105, 106 }, + { 107, 108, 109 } +}, +{ /*i1=1*/ + { 201, 202, 203 }, + { 204, 205, 206 }, + { 207, 208, 209 } +} +} })"; EXPECT_EQ(expected, result); } @@ -1327,13 +1377,26 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { absl::StrContains(status.error_message(), "bit widths are different")); } +// Sets the layout of the given ShapeProto to the default. 
+void SetDefaultLayoutOnProto(ShapeProto* shape_proto) { + CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type())); + shape_proto->mutable_layout()->set_format(DENSE); + auto* minor_to_major = + shape_proto->mutable_layout()->mutable_minor_to_major(); + minor_to_major->Resize(shape_proto->dimensions_size(), 0); + const int64 size = minor_to_major->size(); + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, size - 1 - i); + } +} + TEST_F(LiteralUtilTest, CopyFromProto_Bool) { LiteralProto p; p.mutable_shape()->set_element_type(PRED); for (int len = 0; len < 25; ++len) { p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(len); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_preds(); for (int i = 0; i < len; ++i) { p.add_preds((i % 2) == (len % 2)); @@ -1359,7 +1422,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { EXPECT_EQ(4, m.data().size()); LiteralProto p = m.ToProto(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); + EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape()))); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); EXPECT_EQ(d[0], 0); @@ -1382,7 +1445,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); @@ -1404,7 +1467,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) { p.mutable_shape()->set_element_type(U16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_u16s(); p.set_u16s(uint16_vals, 8); TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); @@ -1537,9 +1600,9 @@ 
TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nested_tuple = LiteralUtil::MakeTuple( {&tuple_elements[0], &tuple_elements[1], &nil_literal}); - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); std::vector elements = nested_tuple.DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1590,7 +1653,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { EXPECT_EQ(literal.Get({1}, /*shape_index=*/{2, 1}), 44.0); for (const Literal& element : elements) { - EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(element.shape())); } } @@ -1706,7 +1769,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { TEST_F(LiteralUtilTest, InvalidProtoNoValues) { // Proto contains a shape, but no values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), @@ -1727,7 +1790,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { // Proto contains values in wrong container. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); proto.add_preds(false); proto.add_preds(true); proto.add_preds(false); @@ -1740,7 +1803,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { // Proto contains too few values. 
LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto(); proto.add_f32s(1.0); proto.add_f32s(2.0); proto.add_f32s(3.0); @@ -1753,7 +1816,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { // Proto contains too many values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}); + *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto(); proto.add_s32s(42); proto.add_s32s(-10); proto.add_s32s(100); @@ -1766,8 +1829,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { // Proto shape missing layout. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}); - LayoutUtil::ClearLayout(proto.mutable_shape()); + *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto(); + proto.mutable_shape()->clear_layout(); proto.add_preds(true); proto.add_preds(false); proto.add_preds(true); @@ -1780,11 +1843,13 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { // Proto has the too few tuple elements. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + *proto.mutable_shape() = + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}) + .ToProto(); LiteralProto* element0 = proto.add_tuple_literals(); *element0->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto(); element0->add_preds(false); element0->add_preds(true); @@ -1796,19 +1861,21 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { // Proto has the too many tuple elements. 
LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + *proto.mutable_shape() = + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}) + .ToProto(); LiteralProto* element0 = proto.add_tuple_literals(); *element0->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto(); element0->add_preds(false); element0->add_preds(true); LiteralProto* element1 = proto.add_tuple_literals(); *element1->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 1); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto(); element1->add_f32s(42.0); LiteralProto* element2 = proto.add_tuple_literals(); - *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}); + *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto(); element2->add_f32s(123.0); Status status = Literal::CreateFromProto(proto).status(); diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 40481331b6992103e10e3fe635a030d3bdffebc9..5b568888d14f21c1330556d017eafba6c8dd2228 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -13,15 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from an environtment variable, or a file named by the environment -// variable. +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or a file named by the +// environment variable. 
#include #include #include +#include +#include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -32,7 +37,6 @@ limitations under the License. namespace xla { -static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed @@ -42,12 +46,20 @@ static const char kWS[] = " \t\r\n"; // whitespace // constructor/destructor collisions with other "private" types // in the same named namespace. namespace { + +// Functor which deletes objects by calling `free`. Necessary to free strdup'ed +// strings created by AppendToEnvArgv. +struct FreeDeleter { + void operator()(char* ptr) { free(ptr); } +}; + struct EnvArgv { EnvArgv() : initialized(false), argc(0) {} bool initialized; // whether the other fields have been set. int argc; // elements used in argv[] std::vector argv; // flag arguments parsed from environment string. - std::vector argv_save; // saved values from argv[] to avoid leaks + // saved values from argv[] to avoid leaks + std::vector> argv_save; }; } // anonymous namespace @@ -63,7 +75,7 @@ static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1, string s = string(s0, s0len) + string(s1, s1len); char* str = strdup(s.c_str()); a->argv.push_back(str); - a->argv_save.push_back(str); + a->argv_save.emplace_back(str); a->argc++; } } @@ -127,14 +139,14 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { } } -// Call ParseArgvFromString(..., a) on a string derived from the setting of an -// environment variable kEnvVar, or a file it points to. 
-static void SetArgvFromEnv(EnvArgv* a) { +// Call ParseArgvFromString(..., a) on a string derived from the setting of the +// environment variable `envvar`, or a file it points to. +static void SetArgvFromEnv(absl::string_view envvar, EnvArgv* a) { if (!a->initialized) { static const char kDummyArgv[] = ""; AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0, a); // dummy argv[0] - const char* env = getenv(kEnvVar); + const char* env = getenv(string(envvar).c_str()); if (env == nullptr || env[0] == '\0') { // nothing } else if (env[strspn(env, kWS)] == '-') { // flags in env var value @@ -157,48 +169,66 @@ static void SetArgvFromEnv(EnvArgv* a) { } } -// The simulated argv[] parsed from the environment. -static EnvArgv* env_argv; +// The simulated argv[] parsed from the environment, one for each different +// environment variable we've seen. +static std::unordered_map& EnvArgvs() { + static auto* env_argvs = new std::unordered_map(); + return *env_argvs; +} -// Used to protect accesses to env_argv. +// Used to protect accesses to env_argvs. static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED); -// Call Flags::Parse(argc, argv, flag_list) against any as yet unrecognized -// flags passed in from the environment. -bool ParseFlagsFromEnv(const std::vector& flag_list) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - SetArgvFromEnv(env_argv); // a no-op if already initialized +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list) { + tensorflow::mutex_lock lock(env_argv_mu); + auto* env_argv = &EnvArgvs()[string(envvar)]; + SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized bool result = tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); - env_argv_mu.unlock(); + + // There's always at least one unparsed argc, namely the fake argv[0]. + if (result && env_argv->argc != 1) { + // Skip the first argv, which is the fake argv[0]. 
+ auto unknown_flags = absl::MakeSpan(env_argv->argv); + unknown_flags.remove_prefix(1); + + // Some flags are set on XLA_FLAGS, others on TF_XLA_FLAGS. If we find an + // unrecognized flag, suggest the alternative. + string alternate_envvar; + if (envvar == "TF_XLA_FLAGS") { + alternate_envvar = "XLA_FLAGS"; + } else if (envvar == "XLA_FLAGS") { + alternate_envvar = "TF_XLA_FLAGS"; + } + string did_you_mean; + if (!alternate_envvar.empty()) { + did_you_mean = absl::StrFormat( + "\nPerhaps you meant to specify these on the %s envvar?", + alternate_envvar); + } + + LOG(FATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") + << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") + << did_you_mean; + return false; + } return result; } // Testing only. -// Reset the env_argv struct so that subsequent calls to ParseFlagsFromEnv() -// will parse the environment variable (or the file it points to) anew, and set -// *pargc, and *pargv to point to the internal locations of the argc and argv -// constructed from the environment. -void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - if (!env_argv->argv_save.empty()) { - for (int i = 0; env_argv->argv_save[i] != nullptr; i++) { - free(env_argv->argv_save[i]); - } - } - env_argv->initialized = false; - env_argv->argc = 0; - env_argv->argv.clear(); - env_argv->argv_save.clear(); - env_argv_mu.unlock(); - *pargc = &env_argv->argc; - *pargv = &env_argv->argv; +// +// Resets the env_argv struct so that subsequent calls to +// ParseFlagsFromEnvAndDieIfUnknown() will parse the environment variable (or +// the file it points to) anew, and set *pargc, and *pargv to point to the +// internal locations of the argc and argv constructed from the environment. 
+void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv) { + tensorflow::mutex_lock lock(env_argv_mu); + EnvArgvs().erase(string(envvar)); + auto& env_argv = EnvArgvs()[string(envvar)]; + *pargc = &env_argv.argc; + *pargv = &env_argv.argv; } } // namespace xla diff --git a/tensorflow/compiler/xla/parse_flags_from_env.h b/tensorflow/compiler/xla/parse_flags_from_env.h index fe86ee687f8482aaffc2ebe04a723d9a22f2cce6..76940a4299ac50138222333ff250a264cc941288 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.h +++ b/tensorflow/compiler/xla/parse_flags_from_env.h @@ -16,48 +16,58 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ #define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from the environtment variable TF_XLA_FLAGS, or (if the first +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or (if the first // non-whitespace in the variable value is not '-'), a file named by that -// environment variable. The accepted syntax is that flags arguments are of -// the form --flag=value or (for boolean flags) --flag, and are whitespace -// separated. The may be one of: -// - -// in which case the effective value is the string itself -// - in which case the effective value is the -// string with the single-quotes removed -// - in which case the effective value if the -// string with the double-quotes removed, and escaped sequences of -// replaced by . +// environment variable. +// +// The accepted syntax is that flags arguments are of the form --flag=value or +// (for boolean flags) --flag, and are whitespace separated. 
The may be +// one of: +// +// - +// in which case the effective value is the string itself +// - in which case the effective value is the +// string with the single-quotes removed +// - in which case the effective value if the +// string with the double-quotes removed, and escaped sequences of +// replaced by . // // Flags values inconsistent with the type of the flag will be rejected by the // flag parser. // // Examples: -// TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" // -// TF_XLA_FLAGS=/tmp/flagfile +// - TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" +// - TF_XLA_FLAGS=/tmp/flagfile +// // where /tmp/flagfile might contain -// --some_flag="This is a string containing a \" and a '." -// --another_flag=wombats +// +// --some_flag="This is a string containing a \" and a '." +// --another_flag=wombats #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" namespace xla { -// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet -// unrecognized flags passed in from the environment, and return its -// return value. -bool ParseFlagsFromEnv(const std::vector& flag_list); +// Calls tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet +// unrecognized flags passed in the environment variable `envvar`, and returns +// its return value. +// +// Raises a fatal error if any flags in `envvar` were not recognized. +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list); // Used only for testing. Not to be used by clients. 
-void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv); +void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv); } // namespace xla diff --git a/tensorflow/compiler/xla/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc index edd6538402d6ceee292ca6a265f490be9709d3ae..3465552ebbf52140fb954b247d99d3c6afe7fcde 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -37,20 +37,7 @@ static void TestParseFlagsFromEnv(const char* msg) { // Initialize module under test. int* pargc; std::vector* pargv; - ResetFlagsFromEnvForTesting(&pargc, &pargv); - - // Ensure that environment variable can be parsed when - // no flags are expected. - std::vector empty_flag_list; - bool parsed_ok = ParseFlagsFromEnv(empty_flag_list); - CHECK(parsed_ok) << msg; - const std::vector& argv_first = *pargv; - CHECK_NE(argv_first[0], nullptr) << msg; - int i = 0; - while (argv_first[i] != nullptr) { - i++; - } - CHECK_EQ(i, *pargc) << msg; + ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv); // Check that actual flags can be parsed. 
bool simple = false; @@ -65,7 +52,7 @@ static void TestParseFlagsFromEnv(const char* msg) { tensorflow::Flag("single_quoted", &single_quoted, ""), tensorflow::Flag("double_quoted", &double_quoted, ""), }; - parsed_ok = ParseFlagsFromEnv(flag_list); + bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); CHECK_EQ(*pargc, 1) << msg; const std::vector& argv_second = *pargv; CHECK_NE(argv_second[0], nullptr) << msg; @@ -171,7 +158,8 @@ int main(int argc, char* argv[]) { tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = xla::ParseFlagsFromEnv(flag_list); + bool parse_ok = + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); if (!parse_ok) { LOG(QFATAL) << "can't parse from environment\n" << usage; } diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index b507a2ef79f1d7e9ae632744675dddf574490805..ac342bf40fbc0052acbb09a346b9d062561ed06b 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -40,16 +40,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, namespace { -string SanitizeFilename(const string& file_name) { - string safe_file_name = file_name; - for (char& c : safe_file_name) { - if (c == '/' || c == '\\') { - c = '_'; - } - } - return safe_file_name; -} - std::pair>*> GetDirectoryExpanders() { static auto* mutex = new tensorflow::mutex; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 4d2a37cfac3e0e89d189f168031e5db44ca5d410..6e2ee866321a070d55a7221c7c68024ceaa93448 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -148,14 +148,19 @@ static StatusOr ToBuffer(LocalClient* client, /* static */ StatusOr 
LocalShapedBuffer::FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout) { + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number) { LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " + << replica_number << "/" << device_ordinal; StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, relaid); + return ToBuffer(client, device_ordinal, relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument); + return ToBuffer(client, device_ordinal, argument); }(); TF_RETURN_IF_ERROR(buf.status()); return new LocalShapedBuffer(std::move(buf).ValueOrDie()); @@ -312,67 +317,127 @@ CompiledLocalComputation::CompiledLocalComputation( StatusOr CompiledLocalComputation::Execute( absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); + StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + StatusOr result_buffer_status; + if (!device_ordinal_status.ok()) { + result_buffer_status = device_ordinal_status.status(); + } else { + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(1, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + 
client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); - VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; + result_buffer_status = executable_->Run(argument_buffers, options); + } + + if (!result_buffer_status.ok()) { + return InternalError( + "Failed running replica 0 (other replicas may have failed as well): " + "%s.", + result_buffer_status.status().ToString()); + } + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); +} + +StatusOr CompiledLocalComputation::ExecutePerReplica( + absl::Span> argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + const int num_replicas = GetReplicaCount(); + + if (argument_handles.size() != num_replicas) { + return InvalidArgument( + "Attempted to execute with %d replicas when replica count is %d", + argument_handles.size(), num_replicas); + } + + VLOG(1) << "Executing with " << num_replicas << " replicas."; // Each replica populates a StatusOr result, but only the output value of // replica zero is returned. 
- std::vector> results(GetReplicaCount()); - { + std::vector> results(num_replicas); + auto execute = [this, client, num_replicas, &argument_handles, + &results](int replica) { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles[replica].size()); + for (auto& handle : argument_handles[replica]) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(num_replicas, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + + results[replica] = std::move(result_buffer_status); + }; + + if (num_replicas == 1) { + // Fast-path if there is only one replica — run the computation on the + // current thread. + execute(0); + } else { + // TODO(phawkins): don't recreate the threadpool for each execution. 
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - GetReplicaCount()); - - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule( - [this, client, replica, &argument_handles, &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer()); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - - results[replica] = std::move(result_buffer_status); - }); + num_replicas - 1); + + for (int replica = 0; replica < num_replicas - 1; ++replica) { + pool.Schedule([&execute, replica] { execute(replica); }); } + execute(num_replicas - 1); } - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - const auto& statusor = results[replica]; + std::vector wrapped_results(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) { + auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", replica, statusor.status().ToString()); } + 
wrapped_results[replica] = + new LocalShapedBuffer(std::move(statusor).ValueOrDie()); } - return new LocalShapedBuffer(std::move(results[0]).ValueOrDie()); + return new LocalShapedBufferTuple(std::move(wrapped_results)); } static StatusOr GetReturnValueShape(const XlaComputation& computation) { @@ -487,12 +552,13 @@ StatusOr LocalComputation::CompileForXrt( xrt::XLAComputation c; auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); + ProgramShape shapes; for (auto& shape : argument_shapes) { - *shapes->add_parameters() = shape; + *shapes.add_parameters() = shape; } - TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(shapes); + TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); + LayoutUtil::SetToDefaultLayout(&shapes); + *config->mutable_program_shape() = shapes.ToProto(); auto snapshot = computation().Snapshot().ValueOrDie(); *c.mutable_hlo_snapshot() = *snapshot; @@ -584,9 +650,9 @@ LocalOp LocalComputationBuilder::Broadcast( } LocalOp LocalComputationBuilder::BroadcastInDim( - const LocalOp& operand, const Shape& shape, + const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { - return xla::BroadcastInDim(operand.op(), shape, broadcast_dimensions); + return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 9e617c48bdc5ae4b37c1a1db9a1876bb4c0a6f0d..149e44570df5c6a3df88bbe2ffa779be47842d82 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -71,7 +71,8 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, const absl::optional& 
shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); StatusOr ToLiteral() const; @@ -175,6 +176,12 @@ class CompiledLocalComputation { StatusOr Execute( absl::Span argument_handles); + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. + StatusOr ExecutePerReplica( + absl::Span > argument_handles); + private: std::unique_ptr executable_; }; @@ -282,7 +289,8 @@ class LocalComputationBuilder { LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); - LocalOp BroadcastInDim(const LocalOp& operand, const Shape& shape, + LocalOp BroadcastInDim(const LocalOp& operand, + absl::Span out_dim_sizes, absl::Span broadcast_dimensions); LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index feabfdb889ca055550c5d1e1c05ca47c1b0bd166..d23d693c1e5bde43b52959e4397aa311268411bb 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -363,6 +363,37 @@ tensorflow::ImportNumpy(); $1 = temps; } +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + std::vector vec; + const int vec_size = PySequence_Size(o); + vec.reserve(vec_size); + for (int j = 0; j < vec_size; ++j) { + PyObject* vec_elt = PySequence_GetItem(o, j); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(vec_elt, (void**) &lsbp, 
$descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + Py_DECREF(vec_elt); + Py_DECREF(o); + SWIG_fail; + } + vec.push_back(lsbp); + Py_DECREF(vec_elt); + } + temps.push_back(vec); + Py_DECREF(o); + } + $1 = temps; +} + %typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { @@ -921,22 +952,22 @@ tensorflow::ImportNumpy(); $1 = NULL; } else { if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { - build_options.set_generate_hlo_graph(std::move(s)); + build_options.mutable_debug_options()->set_xla_generate_hlo_graph(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { - build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_optimized_hlo_proto_to(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { - build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_unoptimized_hlo_proto_to(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { - build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_per_pass_hlo_proto_to(std::move(s)); })) { return nullptr; } @@ -950,7 +981,7 @@ tensorflow::ImportNumpy(); PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); SWIG_fail; } - build_options.set_hlo_profile(o == Py_True); + build_options.mutable_debug_options()->set_xla_hlo_profile(o == Py_True); } Py_DECREF(o); @@ -992,11 +1023,13 @@ tensorflow::ImportNumpy(); %unignore xla::swig::XrtAllocation; %unignore xla::swig::XrtAllocation::FromLiteral; %unignore xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; %unignore xla::swig::XrtAllocationTuple; %unignore 
xla::swig::XrtAllocationTuple::Release; %unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; %unignore xla::swig::CompiledXrtComputation; %unignore xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::LocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 92b0685dbba195405d78867776fe43b5f6c60f4c..c91a2aaf56dfe2127168628c78e0c4b868a28055 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -26,6 +26,9 @@ import os import numpy as np +import six +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api from tensorflow.compiler.xla.service import hlo_pb2 @@ -75,6 +78,13 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -212,23 +222,33 @@ class LocalBuffer(object): means the referent is in device memory. 
""" - def __init__(self, c_buffer, backend): + def __init__(self, c_buffer, backend, replica): self.c_buffer = c_buffer self._backend = backend + self._replica = replica if backend.backend_type == BackendType.XRT: self._delete = c_api.DeleteXrtAllocation else: self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_pyval(pyval, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): """Allocate and copy to XLA the given python value.""" pyval = require_numpy_array_layout(pyval) + num_replicas = get_replica_count() + if not 0 <= replica < num_replicas: + raise ValueError( + 'Attempt to place buffer on replica {} when the replica count is {}' + .format(replica, num_replicas)) if backend.backend_type == BackendType.XRT: - cbuf = c_api.XrtAllocation.FromLiteral(pyval, backend.target) + if replica != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + cbuf = c_api.XrtAllocation.FromLiteral( + pyval, _maybe_encode_string(backend.target)) else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None) - return LocalBuffer(cbuf, backend) + cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) + return LocalBuffer(cbuf, backend, replica) def to_py(self): return self.c_buffer.ToLiteral() @@ -236,6 +256,9 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) + def replica(self): + return self._replica + def delete(self): if self.c_buffer is not None: self._delete(self.c_buffer) @@ -245,14 +268,15 @@ class LocalBuffer(object): """Assuming a tuple buffer, unpack it into constituent tuple elements.""" assert self.c_buffer is not None if self._backend.backend_type == BackendType.XRT: - result = c_api.DestructureXrtAllocationTuple(self.c_buffer, - self._backend.target) + result = c_api.DestructureXrtAllocationTuple( + self.c_buffer, _maybe_encode_string(self._backend.target)) else: result = 
c_api.DestructureLocalShapedBufferTuple(self.c_buffer) self.delete() size = result.size() destructured = tuple( - LocalBuffer(result.Release(i), backend=self._backend) + LocalBuffer( + result.Release(i), replica=self._replica, backend=self._backend) for i in xrange(size)) return destructured @@ -322,6 +346,9 @@ class Shape(object): def __ne__(self, other): return not self == other + def __hash__(self): + return hash((self._dtype, self._dimensions, self._minor_to_major)) + def __repr__(self): return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' '_is_tuple={!r}, _minor_to_major={!r})').format( @@ -541,10 +568,13 @@ class LocalComputation(object): ] result_shape = result_shape.map_leaves(layout_fn) + argument_shapes = list(argument_shapes) + compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape if self._backend.backend_type == BackendType.XRT: - c = self.computation.CompileForXrt(argument_shapes, self._backend.target) + c = self.computation.CompileForXrt( + argument_shapes, _maybe_encode_string(self._backend.target)) else: c = self.computation.Compile(argument_shapes, compile_options) return LocalComputation(c, is_compiled=True, backend=self._backend) @@ -558,23 +588,87 @@ class LocalComputation(object): compile_options=compile_options, layout_fn=layout_fn) - def Execute(self, arguments=()): - """Execute with LocalBuffer arguments and return value.""" + def GetReturnValueShape(self): + return _wrap_shape(self._c_computation.GetReturnValueShape()) + + def Execute(self, arguments=(), check_for_deleted_args=True): + """Execute on one replica with LocalBuffer arguments and return value.""" + if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + raw_args = [arg.c_buffer for arg in arguments] + output_buffer = self._c_computation.Execute(raw_args) + return LocalBuffer(output_buffer, backend=self._backend, replica=0) + + def 
ExecutePerReplica(self, arguments=None): + """Execute on many replicas with LocalBuffer arguments and return value. + + Args: + arguments: A sequence of sequences of LocalBuffers. The i'th inner + sequence comprises the arguments for execution on the i'th replica. + + Returns: + A tuple with one LocalBuffer per replica, holding that replica's output; + element i is the output of the i'th replica. If `arguments` is None, an + empty argument list is used for each replica. + """ if not self._is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - arguments = tuple(arguments) - if any(arg.is_deleted() for arg in arguments): - raise ValueError('Executing with deleted local buffer argument') - return LocalBuffer( - self._c_computation.Execute([arg.c_buffer for arg in arguments]), - backend=self._backend) + if arguments is None: + arguments = ((),) * get_replica_count() + else: + arguments = [list(replica_args) for replica_args in arguments] + + # Check arguments + for replica, replica_args in enumerate(arguments): + for arg in replica_args: + if arg.is_deleted(): + raise ValueError('Executing with deleted local buffer argument') + if arg.replica() != replica: + raise ValueError( + 'Executing on replica {} with argument from replica {}'.format( + replica, arg.replica())) + + # Pull out argument buffer handles + stripped_args = [ + [arg.c_buffer for arg in replica_args] for replica_args in arguments + ] + + # Execute + if self._backend.backend_type == BackendType.XRT: + if len(stripped_args) > 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + output_buffers = [self._c_computation.Execute(stripped_args[0])] + else: + output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) + size = output_buffer_tup.size() + output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + + # Wrap output handles in LocalBuffer instances + return tuple( +
LocalBuffer(output_buffer, backend=self._backend, replica=replica) + for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): - """Execute with Python values as arguments and return value.""" - arguments = tuple( - LocalBuffer.from_pyval(arg, backend=self._backend) for arg in arguments) + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): + return LocalBuffer.from_pyval(arg, backend=self._backend) + + arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() + def ExecuteWithPythonValuesPerReplica(self, arguments): + """Execute on many replicas with Python values as arguments and output.""" + + def put(arg, replica): + return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + + arguments = [[put(arg, replica) + for arg in replica_args] + for replica, replica_args in enumerate(arguments)] + return [out.to_py() for out in self.ExecutePerReplica(arguments)] + def __del__(self): self._delete(self._c_computation) @@ -761,8 +855,7 @@ class ComputationBuilder(object): Returns: A LocalOp representing the added broadcast-in-dimensions op. """ - xla_shape = Shape.array_shape(self.GetShape(operand).element_type(), shape) - return self._client.BroadcastInDim(operand, xla_shape, broadcast_dimensions) + return self._client.BroadcastInDim(operand, shape, broadcast_dimensions) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -1380,6 +1473,7 @@ def initialize_platform_name(platform_name): Raises: A runtime exception if the XLA service has already been initialized. 
""" + platform_name = _maybe_encode_string(platform_name) c_api.InitializePlatformName(platform_name) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index f158f6b2410352432445f669155aff0af5526abf..95b2bf300ec67e9f034f77450416544cb088ae55 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -25,9 +25,10 @@ from tensorflow.compiler.xla.python_api import types class Shape(object): - """Wraps a xla_data_pb2.Shape message with a convenient Python type. + """Wraps a xla_data_pb2.ShapeProto message with a convenient Python type. - Provides direct access to the underlying xla_data_pb2.Shape message in the + Provides direct access to the underlying xla_data_pb2.ShapeProto message in + the message attribute, along with accessor wrappers to the message's fields. Avoid direct access to .message unless interacting directly with protobuf APIs like CopyFrom. In other words, prefer hauling the shape around in a Shape, and @@ -48,7 +49,7 @@ class Shape(object): Raises: ValueError: if element_type is TUPLE but dimensions are not Shape objects. 
""" - self.message = xla_data_pb2.Shape() + self.message = xla_data_pb2.ShapeProto() self.message.element_type = element_type if element_type == xla_data_pb2.TUPLE: if not all(isinstance(subshape, Shape) for subshape in dimensions): diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 3abb3855a42b8b5222115262448d359da3a80e87..26affbcceb33110baf41d507173e56f8b1c8c9eb 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -16,7 +16,6 @@ xla_proto_library( use_grpc_plugin = True, visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", ], ) diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index e4f332cda22cc5b889bf73f06913b96d6091dc81..0ff8adc2acbe5fd21e85027dd63bfb14f5672a7d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -43,7 +43,6 @@ limitations under the License. 
syntax = "proto3"; import "tensorflow/compiler/xla/xla.proto"; -import "tensorflow/compiler/xla/xla_data.proto"; package xla; diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 19b5c1ca25debf80c7e712854b47384937697d3d..4c21ae2a427477caa86fb4130616c38eb3bcf006 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -281,10 +281,12 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -292,6 +294,7 @@ cc_library( name = "hlo", srcs = [ "dfs_hlo_visitor.cc", + "dynamic_parameter_binding.cc", "hlo_computation.cc", "hlo_input_output_alias_config.cc", "hlo_instruction.cc", @@ -305,6 +308,7 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "dynamic_parameter_binding.h", "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", @@ -350,6 +354,25 @@ cc_library( ], ) +tf_cc_test( + name = "dynamic_parameter_binding_test", + srcs = ["dynamic_parameter_binding_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + tf_cc_test( name = "dfs_hlo_visitor_with_default_test", srcs = ["dfs_hlo_visitor_with_default_test.cc"], @@ -387,9 +410,36 @@ tf_cc_test( ":hlo", ":pattern_matcher", 
"//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pattern_matcher_gmock", + testonly = 1, + hdrs = ["pattern_matcher_gmock.h"], + deps = [ + ":pattern_matcher", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "pattern_matcher_gmock_test", + srcs = ["pattern_matcher_gmock_test.cc"], + deps = [ + ":hlo", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -403,6 +453,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], @@ -1336,6 +1387,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1539,7 +1591,10 @@ tf_cc_test( ":hlo", ":hlo_casting_utils", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1707,7 +1762,9 @@ cc_library( ":hlo", ":hlo_pass", ":hlo_query", + ":pattern_matcher", ":while_loop_analysis", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1720,9 +1777,14 @@ tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ + ":algebraic_simplifier", ":hlo", + ":hlo_cse", ":hlo_dce", ":hlo_matchers", + 
":hlo_pass", + ":hlo_pass_pipeline", + ":tuple_simplifier", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1848,6 +1910,41 @@ cc_library( ], ) +cc_library( + name = "dynamic_dimension_inference", + srcs = ["dynamic_dimension_inference.cc"], + hdrs = ["dynamic_dimension_inference.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "dynamic_dimension_inference_test", + srcs = ["dynamic_dimension_inference_test.cc"], + deps = [ + ":dynamic_dimension_inference", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "reshape_mover_test", srcs = ["reshape_mover_test.cc"], @@ -2005,7 +2102,8 @@ tf_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":hlo", - ":hlo_matchers", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2347,6 +2445,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -2598,8 +2697,9 @@ tf_cc_test( ":algebraic_simplifier", ":computation_layout", ":hlo", - ":hlo_matchers", ":layout_assignment", + ":pattern_matcher", 
+ ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -2610,6 +2710,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/types:span", @@ -2744,6 +2845,8 @@ tf_cc_test( ":hlo_matchers", ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2855,6 +2958,46 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_get_dimension_size_rewriter", + srcs = ["hlo_get_dimension_size_rewriter.cc"], + hdrs = ["hlo_get_dimension_size_rewriter.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":shape_inference", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "hlo_get_dimension_size_rewriter_test", + srcs = ["hlo_get_dimension_size_rewriter_test.cc"], + deps = [ + ":hlo", + ":hlo_get_dimension_size_rewriter", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "device_memory_allocator", srcs = 
[ @@ -2913,6 +3056,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", @@ -3026,6 +3170,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", + ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -3318,9 +3463,9 @@ cc_library( ":tuple_util", ":while_loop_analysis", ":while_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3463,6 +3608,8 @@ tf_cc_test( ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", @@ -3513,6 +3660,41 @@ cc_library( ], ) +cc_library( + name = "ar_crs_combiner", + srcs = ["ar_crs_combiner.cc"], + hdrs = ["ar_crs_combiner.h"], + deps = [ + ":call_graph", + ":pattern_matcher", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "ar_crs_combiner_test", + srcs = ["ar_crs_combiner_test.cc"], + deps = [ + ":ar_crs_combiner", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "map_inliner_test", srcs = ["map_inliner_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 89e62bd2f0dc02d2d0947ae47e3bb0c9955f103e..985c5af1c4d89425dd6693585e42e22510fe21f8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include #include #include #include @@ -68,6 +69,45 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +// Checks whether `op` is a floating-point constant or broadcast of a constant +// of the form +/- 2^k for some integer k positive, negative, or zero. Such +// values are interesting because multiplying by a power of 2 just moves the +// exponent. +bool IsAllFpConstantPowerOf2(const HloInstruction* op) { + // Unwrap the broadcast if necessary. + const HloInstruction* c; + if (!Match(op, m::ConstantEffectiveScalar(&c)) && + !Match(op, m::Broadcast(m::Constant(&c).WithShape( + m::Shape().IsEffectiveScalar())))) { + return false; + } + auto val = [&]() -> absl::optional { + switch (c->shape().element_type()) { + case BF16: + return static_cast(c->literal().GetFirstElement()); + case F16: + return static_cast(c->literal().GetFirstElement()); + case F32: + return c->literal().GetFirstElement(); + case F64: + return c->literal().GetFirstElement(); + default: + // Cowardly refuse to consider complex types. + return absl::nullopt; + } + }(); + if (!val) { + return false; + } + + int exp; + double mantissa = std::frexp(*val, &exp); + // frexp returns a value in the range (-1; -0.5] U [0.5, 1). 
A return value + // of +/-0.5 therefore indicates that the floating point value is a power of + // 2. + return mantissa == 0.5 || mantissa == -0.5; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -84,7 +124,8 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. bool ReshapeOrCopyIsBitcast( const HloInstruction* instr, - const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { + const AlgebraicSimplifierOptions::ValidBitcastCallback& + valid_bitcast_callback) { CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); @@ -95,6 +136,11 @@ bool ReshapeOrCopyIsBitcast( valid_bitcast_callback(operand->shape(), instr->shape()); } +bool IsUnstridedSlice(const HloInstruction* hlo) { + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); +} + // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. For the @@ -180,21 +226,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const bool changed() const { return changed_; } // Runs the visitor on a computation. 
- static bool Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification); + static bool Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options); private: - explicit AlgebraicSimplifierVisitor( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) - : computation_(computation), - is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + explicit AlgebraicSimplifierVisitor(HloComputation* computation, + const AlgebraicSimplifierOptions& options) + : computation_(computation), options_(options) {} // Transforms Dots where at least one input is a vector or has a degenerate // dimension and converts it into a multiply and reduce. This should enable @@ -233,10 +271,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* new_instruction); // Returns whether the shape of the output of the given instructions are the - // same for the purposes of simplification. If is_layout_sensitive_ is true, - // then this tests shape equality including layout (ShapeUtil::Equal). If - // is_layout_sensitive_ is false, then the tests shape compatibility - // (ShapeUtil::Compatible). + // same for the purposes of simplification. If options_.is_layout_sensitive() + // is true, then this tests shape equality including layout + // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then this + // tests shape compatibility (ShapeUtil::Compatible).
bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; // Returns whether it was possible to transform `root` to a clamp instruction. @@ -325,22 +363,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // traversing. HloComputation* computation_; + // The backend-specific options selected for the algebraic simplifier. + const AlgebraicSimplifierOptions& options_; + // Whether algebraic simplification has occurred. bool changed_ = false; - // Whether layout is considered during transformation. - bool is_layout_sensitive_; - - // Callback used to determine if a bitcast is possible. - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - - // Disable dot strength reduction on platforms where it causes a slowdown. - bool enable_dot_strength_reduction_; - - // Disable convolution -> dot simplification on platforms where it causes a - // slowdown. - bool enable_conv_simplification_; - // Cached computation for adding two scalar F32. HloComputation* scalar_add_computation_ = nullptr; }; @@ -348,19 +376,15 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { } // namespace bool AlgebraicSimplifierVisitor::Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) { - AlgebraicSimplifierVisitor visitor( - computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_strength_reduction, enable_conv_simplification); + HloComputation* computation, const AlgebraicSimplifierOptions& options) { + AlgebraicSimplifierVisitor visitor(computation, options); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const { - if (is_layout_sensitive_) { + if (options_.is_layout_sensitive()) { return ShapeUtil::Equal(lhs->shape(), 
rhs->shape()); } else { return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); @@ -431,6 +455,40 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // A*C + B*C => (A+B)*C + // + // - If A, B, and C are integers, do this unconditionally. Proof of + // correctness: https://rise4fun.com/Alive/u9X. + // + // - If A, B, and C are floating point, do this if C is a scalar constant or + // broadcast of scalar constant and is equal to +/- 2^k for some (possibly + // negative) integer k. + // + // Multiplying by a power of 2 just moves the exponent, so our answer is + // exact modulo rounding of intermediate results so long as + // + // - none of the three products has an exponent which underflows (so the + // result is 0 or denormal), and + // - none of the three products overflows to inf. + // + // Proof: See algebraic_simplifier_proof_distributive_property.py. + // + // We deem these differences in rounding, underflow, and overflow + // acceptable in the ML context. 
+ HloInstruction *b, *c; + if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || + (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && + (ShapeUtil::ElementIsIntegral(add->shape()) || + IsAllFpConstantPowerOf2(c))) { + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, a, b)), + c)); + } return Status::OK(); } @@ -504,8 +562,8 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { ReplaceWithBitcast(copy); } @@ -541,7 +599,74 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( VLOG(10) << "trying to replace " << concatenate->ToString() << " with " << replacement->ToString(); ReplaceInstructionIfSameShape(concatenate, replacement); - } else if (operands.size() == 2) { + return Status::OK(); + } + + // Check if we can merge "adjacent" slice operands which take slices from the + // same other op. For simplicity we only merge unstrided slices. 
+ int64 concatenate_dimension = concatenate->concatenate_dimension(); + for (int64 i = 0; i < operands.size(); ++i) { + if (operands[i]->opcode() != HloOpcode::kSlice || + !IsUnstridedSlice(operands[i])) { + continue; + } + int64 slice_end = operands[i]->slice_limits(concatenate_dimension); + HloInstruction* slice_operand = operands[i]->mutable_operand(0); + int64 j = i + 1; + while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice && + IsUnstridedSlice(operands[j]) && + operands[j]->operand(0) == slice_operand && + operands[j]->slice_starts(concatenate_dimension) == slice_end) { + // Check that all the slice_start values are the same in all other + // dimensions. This implies that the slice_limit values are also the same, + // because operands of concatenate need to have the same shape, and we + // already checked that the slices are unstrided. + bool same_other_starts = true; + for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) { + if (k == concatenate_dimension) { + continue; + } + if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) { + same_other_starts = false; + break; + } + } + if (!same_other_starts) { + break; + } + slice_end = operands[j]->slice_limits(concatenate_dimension); + ++j; + } + if (j - i > 1) { + Shape new_slice_shape = operands[i]->shape(); + new_slice_shape.set_dimensions( + concatenate_dimension, + slice_end - operands[i]->slice_starts(concatenate_dimension)); + auto new_limit_indices = operands[i]->slice_limits(); + new_limit_indices[concatenate_dimension] = slice_end; + auto new_slice_op = + computation_->AddInstruction(HloInstruction::CreateSlice( + new_slice_shape, slice_operand, + /*start_indices=*/operands[i]->slice_starts(), + /*limit_indices=*/new_limit_indices, + /*strides=*/operands[i]->slice_strides())); + std::vector new_operands; + for (int64 k = 0; k < i; ++k) { + new_operands.push_back(operands[k]); + } + new_operands.push_back(new_slice_op); + for (int64 k = j; k < operands.size(); 
++k) { + new_operands.push_back(operands[k]); + } + auto replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), new_operands)); + ReplaceInstructionIfSameShape(concatenate, replacement); + return Status::OK(); + } + } + + if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. bool is_effective_low_pad = Match( @@ -557,7 +682,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); padding_config_dim->set_interior_padding(0); - if (dim == concatenate->concatenate_dimension()) { + if (dim == concatenate_dimension) { if (is_effective_low_pad) { padding_config_dim->set_edge_padding_low( operands[0]->shape().dimensions(dim)); @@ -1215,7 +1340,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_gather_optimized); } - if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); if (did_strength_reduction) { @@ -1619,6 +1745,27 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); } + + // Interior padding on one-sized dimensions has no effect. As a result it + // makes other simplifications possible if there is no interior padding.
+ if (HasInteriorPadding(pad->padding_config())) { + PaddingConfig padding_config = pad->padding_config(); + bool cleared_interior_padding = false; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + if (padding_config.dimensions(i).interior_padding() > 0 && + pad->operand(0)->shape().dimensions(i) == 1) { + cleared_interior_padding = true; + padding_config.mutable_dimensions(i)->set_interior_padding(0); + } + } + if (cleared_interior_padding) { + return ReplaceWithNewInstruction( + pad, + HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0), + pad->mutable_operand(1), padding_config)); + } + } + // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice. bool all_zero = true; @@ -1910,8 +2057,8 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. - if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -2030,11 +2177,6 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( return false; } -bool IsUnstridedSlice(const HloInstruction* hlo) { - return absl::c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); -} - StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( HloInstruction* slice) { CHECK_EQ(slice->opcode(), HloOpcode::kSlice); @@ -2501,6 +2643,108 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } + if (!options_.enable_permutation_sort_replacement()) { + return Status::OK(); + } + // Check if we are sorting a permutation. 
In that case, we know that the keys + // will be sorted to the identity permutation, and we can represent the + // changes to the 'values' parameter as a scatter. + if (sort->operand_count() == 2 && + operand->opcode() == HloOpcode::kGetTupleElement) { + const HloInstruction* other_sort = operand->operand(0); + // Check whether the 'values' parameter is the result of another sort with + // the same sort dimension. + if (other_sort->opcode() == HloOpcode::kSort && + other_sort->operand_count() >= 2 && + other_sort->dimensions(0) == dimension_to_sort && + other_sort->operand(operand->tuple_index())->opcode() == + HloOpcode::kIota) { + auto* iota = + Cast(other_sort->operand(operand->tuple_index())); + // The sort operand needs to be an integral iota, and the iota dimension + // needs to be the dimension that was sorted. + if (iota->iota_dimension() == dimension_to_sort && + ShapeUtil::ElementIsIntegral(iota->shape())) { + // We use the following construction method for a Scatter that applies + // the permutation from 'keys' to the 'values' parameter. + // - Take the "keys" parameter of the second sort and reshape it to have + // another "1" dimension at the end. + // - Concatenate it with iotas of the same extended shape with all + // different iota_dimensions except the dimension_to_sort in the order + // of iota_dimensions/dimension_to_sort, so e.g. with rank 3 and + // dimension_to_sort = 1, we would have concatenate of (iota with + // iota_dimension=0, keys, iota with iota_dimension = 2) + // - Use this as the indices parameter of scatter, and set updates + // of the scatter to be a reshaped 'values' parameter of sort (adding + // 'rank' many 1 dimensions at the end). 
+ int64 rank = ShapeUtil::Rank(operand->shape()); + Shape extended_shape = operand->shape(); + extended_shape.add_dimensions(1); + extended_shape.mutable_layout()->add_minor_to_major(rank); + auto reshaped_permutation = computation_->AddInstruction( + HloInstruction::CreateReshape(extended_shape, operand)); + std::vector concat_operands; + for (int64 i = 0; i < rank; ++i) { + if (i == dimension_to_sort) { + concat_operands.push_back(reshaped_permutation); + } else { + concat_operands.push_back(computation_->AddInstruction( + HloInstruction::CreateIota(extended_shape, i))); + } + } + Shape concat_shape = operand->shape(); + concat_shape.add_dimensions(rank); + concat_shape.mutable_layout()->add_minor_to_major(rank); + auto scatter_indices = + rank > 1 ? computation_->AddInstruction( + HloInstruction::CreateConcatenate( + concat_shape, concat_operands, rank)) + : reshaped_permutation; + + // We don't care about the operand, it will be completely overridden by + // the updates. + auto scatter_operand = computation_->AddInstruction( + HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); + + // Construct the updates operand of scatter. + Shape update_shape = sort->operand(1)->shape(); + for (int64 i = 0; i < rank; ++i) { + update_shape.add_dimensions(1); + update_shape.mutable_layout()->add_minor_to_major(rank + i); + } + auto scatter_updates = + computation_->AddInstruction(HloInstruction::CreateReshape( + update_shape, sort->mutable_operand(1))); + + // Construct the updates computation, which simply replaces the operand + // values with the update values. 
+ HloComputation::Builder b("update_replace_computation"); + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); + auto update_replace_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); + + ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(rank); + for (int64 i = 0; i < rank; ++i) { + dim_numbers.add_update_window_dims(rank + i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + auto scatter = + computation_->AddInstruction(HloInstruction::CreateScatter( + sort->operand(1)->shape(), scatter_operand, scatter_indices, + scatter_updates, update_replace_computation, dim_numbers)); + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple( + {computation_->AddInstruction(HloInstruction::CreateIota( + operand->shape(), dimension_to_sort)), + scatter})); + } + } + } return Status::OK(); } @@ -2525,7 +2769,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return ReplaceInstruction(transpose, operand); } - if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { + if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); } @@ -2674,13 +2918,13 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (!enable_conv_simplification_) { + if (!options_.enable_conv_simplification()) { return false; } // TODO(b/31337498): For now, we cowardly refuse to do this optimization in // layout-insensitive mode, for fear of adding nontrivial reshapes. 
- if (!is_layout_sensitive_) { + if (!options_.is_layout_sensitive()) { return false; } @@ -2770,9 +3014,9 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be // invalid. - if (!valid_bitcast_callback_(input_shape, new_input_shape) || - !valid_bitcast_callback_(filter_shape, new_filter_shape) || - !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || + !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || + !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { return false; } @@ -2878,9 +3122,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run( - comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_strength_reduction_, enable_conv_simplification_)) { + if (AlgebraicSimplifierVisitor::Run(comp, options_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f8d0ee88bdebcf17310cd0407b1b99e4b0a7b5f..d2775b9fafa7e4c625f5d181114e80e7369f9c78 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,8 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs algebraic simplifications. 
-class AlgebraicSimplifier : public HloModulePass { +class AlgebraicSimplifierOptions { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform @@ -34,18 +33,63 @@ class AlgebraicSimplifier : public HloModulePass { using ValidBitcastCallback = std::function; + explicit AlgebraicSimplifierOptions( + ValidBitcastCallback valid_bitcast_callback) + : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + // If valid_bitcast_callback returns true, then the pass will replace reshapes + // and transposes with bitcasts. + const ValidBitcastCallback& valid_bitcast_callback() const { + return valid_bitcast_callback_; + } + + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + void set_is_layout_sensitive(bool is_layout_sensitive) { + is_layout_sensitive_ = is_layout_sensitive; + } + bool is_layout_sensitive() const { return is_layout_sensitive_; } + + // Enable dot simplification on platforms where it is profitable. + void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { + enable_dot_strength_reduction_ = enable_dot_strength_reduction; + } + bool enable_dot_strength_reduction() const { + return enable_dot_strength_reduction_; + } + + // Enable convolution simplification on platforms where it is profitable. + void set_enable_conv_simplification(bool enable_conv_simplification) { + enable_conv_simplification_ = enable_conv_simplification; + } + bool enable_conv_simplification() const { + return enable_conv_simplification_; + } + + // If enable_permutation_sort_replacement is true, a sort op that is known to + // sort a permutation will be replaced with a scatter op. 
+ void set_enable_permutation_sort_replacement( + bool enable_permutation_sort_replacement) { + enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; + } + bool enable_permutation_sort_replacement() const { + return enable_permutation_sort_replacement_; + } + + private: + ValidBitcastCallback valid_bitcast_callback_; + bool is_layout_sensitive_{false}; + bool enable_dot_strength_reduction_{true}; + bool enable_conv_simplification_{true}; + bool enable_permutation_sort_replacement_{false}; +}; + +// A pass which performs algebraic simplifications. +class AlgebraicSimplifier : public HloModulePass { + public: // If is_layout_sensitive is true, then the simplifier preserves layout during - // transformation. Otherwise, layout is ignored. If valid_bitcast_callback - // returns true, then the pass will replace reshapes and transposes with - // bitcasts. - AlgebraicSimplifier(bool is_layout_sensitive, - ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction = true, - bool enable_conv_simplification = true) - : is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + // transformation. Otherwise, layout is ignored. + explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) + : options_(options) {} ~AlgebraicSimplifier() override = default; absl::string_view name() const override { return "algsimp"; } @@ -54,14 +98,7 @@ class AlgebraicSimplifier : public HloModulePass { StatusOr Run(HloModule* module) override; private: - bool is_layout_sensitive_; - ValidBitcastCallback valid_bitcast_callback_; - - // Enable dot simplification on platforms where it is profitable. - bool enable_dot_strength_reduction_; - - // Enable convolution simplification on platforms where it is profitable. 
- bool enable_conv_simplification_; + AlgebraicSimplifierOptions options_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py new file mode 100644 index 0000000000000000000000000000000000000000..5da13da041b4ded813876af7ca379025187545ab --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Proof that transforming (A*C)+(B*C) <=> (A+B)*C is "safe" if C=2^k. + +Specifically, for all floating-point values A, B, and C, if + + - C is equal to +/- 2^k for some (possibly negative) integer k, and + - A, B, C, A*C, B*C, and A+B are not subnormal, zero, or inf, + +then there exists a rounding mode rm in [RTZ, RNE] such that + + (A*C) + (B*C) == (A+B) * C (computed with rounding mode rm). + +Informally, this means that the equivalence holds for powers of 2 C, modulo +flushing to zero or inf, and modulo rounding of intermediate results. + +Requires z3 python bindings; try `pip install z3-solver`. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import z3 + +# We do float16 because it lets the solver run much faster. These results +# should generalize to fp32 and fp64, and you can verify this by changing the +# value of FLOAT_TY (and then waiting a while). +FLOAT_TY = z3.Float16 + +a = z3.FP("a", FLOAT_TY()) +b = z3.FP("b", FLOAT_TY()) +c = z3.FP("c", FLOAT_TY()) + +s = z3.Solver() + +# C must be a power of 2, i.e. significand bits must all be 0. +s.add(z3.Extract(FLOAT_TY().sbits() - 1, 0, z3.fpToIEEEBV(c)) == 0) + +for rm in [z3.RTZ(), z3.RNE()]: + z3.set_default_rounding_mode(rm) + before = a * c + b * c + after = (a + b) * c + + # Check that before == after, allowing that 0 == -0. + s.add( + z3.Not( + z3.Or( + before == after, # + z3.And(z3.fpIsZero(before), z3.fpIsZero(after))))) + + for x in [ + (a * c), + (b * c), + (a + b), + ]: + s.add(z3.Not(z3.fpIsSubnormal(x))) + s.add(z3.Not(z3.fpIsZero(x))) + s.add(z3.Not(z3.fpIsInf(x))) + +if s.check() == z3.sat: + m = s.model() + print("Counterexample found!") + print(m) + print("a*c: ", z3.simplify(m[a] * m[c])) + print("b*c: ", z3.simplify(m[b] * m[c])) + print("a+b: ", z3.simplify(m[a] + m[b])) + print("a*c + b*c: ", z3.simplify(m[a] * m[c] + m[b] * m[c])) + print("(a+b) * c: ", z3.simplify((m[a] + m[b]) * m[c])) +else: + print("Proved!") diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e4c4da1b0e7aef0e3476e4d232e410da25794e13..14ce519b6a0fd221070006d336d23bddeb6cd621 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -27,9 +27,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -42,18 +44,20 @@ namespace xla { namespace { using ::testing::ElementsAre; +namespace m = match; -namespace op = xla::testing::opcode_matchers; - -AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } -AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase { + protected: + AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -70,13 +74,134 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); 
EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), m::Parameter(2)))); +} + +// A*C + B*C => (A+B)*C if C is a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.125) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::ConstantScalar(0.125)))); +} + +// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + x = f32[4] multiply(p0, b) + y = f32[4] multiply(p1, b) + ROOT sum = f32[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +// A*C + B*C => (A+B)*C simplification should not happen if C is not a +// floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.3) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are +// complex numbers. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = c64[8] parameter(0) + p1 = c64[8] parameter(1) + p2 = c64[8] parameter(2) + x = c64[8] multiply(p0, p2) + y = c64[8] multiply(p1, p2) + ROOT sum = c64[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16.
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = bf16[4] parameter(0) + p1 = bf16[4] parameter(1) + c = bf16[] constant(0.125) + b = bf16[4] broadcast(c), dimensions={} + x = bf16[4] multiply(p0, b) + y = bf16[4] multiply(p1, b) + ROOT sum = bf16[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { auto m = CreateNewVerifiedModule(); @@ -92,8 +217,7 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), zero); } @@ -115,8 +239,7 @@ TEST_F(AlgebraicSimplifierTest, SelectTrue) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -138,8 +261,7 @@ TEST_F(AlgebraicSimplifierTest, SelectFalse) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), 
HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param1); } @@ -159,8 +281,7 @@ TEST_F(AlgebraicSimplifierTest, SelectIdentical) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param1); } @@ -196,11 +317,10 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, dims1, add_computation)); m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = m->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero)))); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } @@ -219,11 +339,10 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Constant())); + EXPECT_THAT(root, 
GmockMatch(m::Add(m::Parameter(0), m::Constant()))); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. @@ -246,11 +365,12 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); + EXPECT_THAT(root, GmockMatch(m::Add( + m::Op().Is(param0), + m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { @@ -269,8 +389,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -306,11 +425,11 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMap); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Op().Is(zero))))); } 
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -329,8 +448,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -344,12 +462,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } @@ -361,12 +478,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { @@ -377,12 +493,11 @@ TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { auto 
computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); } // Test that A - 0 is simplified to A @@ -400,8 +515,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -422,11 +536,11 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Negate(m::Op().Is(constant))))); } // Test that (A/B)/C is simplified to A/(B*C). 
@@ -448,14 +562,16 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Divide(param0, param1), param2)); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Multiply(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/(B/C) is simplified to (A*C)/B. @@ -476,15 +592,18 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Divide(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Divide(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Multiply(param0, param2), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)), + m::Parameter(1)))); } // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
@@ -511,15 +630,16 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Divide(m::Parameter(2), m::Parameter(3))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -539,14 +659,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Exp(param1))); + GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Exp(op::Negate(param1)))); + GmockMatch(m::Multiply(m::Parameter(0), + m::Exp(m::Negate(m::Parameter(1)))))); } // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
@@ -567,15 +687,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // Test that broadcasting is done on the right step when simplifying A/pow(B,C) @@ -597,15 +720,18 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // A / Const => A * InvertedConst @@ -623,12 +749,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Constant())); + GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } // pow(pow(A, X), Y) => pow(A, X*Y) @@ -648,11 +773,12 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { inner_power, exp2)); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Power(base, op::Multiply(exp1, exp2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Power(m::Op().Is(base), + m::Multiply(m::Op().Is(exp1), m::Op().Is(exp2))))); } // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex @@ -673,8 +799,7 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { inner_power, exp2)); m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } @@ -693,8 +818,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -715,8 +839,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); 
+ AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -740,8 +863,7 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -765,8 +887,7 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -790,8 +911,7 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); @@ -818,11 +938,10 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = 
computation->root_instruction(); - EXPECT_THAT(root, op::Add(param1, param2)); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2)))); } // Test that exp(A)/exp(B) is simplified to exp(A-B) @@ -843,15 +962,16 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Exp(param0), op::Exp(param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Subtract(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1))))); } // Test that exp(A)*exp(B) is simplified to exp(A+B) @@ -873,14 +993,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Exp(param0), op::Exp(param1))); + GmockMatch(m::Multiply(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Add(param0, param1))); + GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1))))); } // Test that pow(exp(A), B) is simplified to exp(A*B) @@ -900,14 +1020,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Power(op::Exp(param0), param1)); + GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1)))); - 
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Multiply(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1))))); } // Test that ln(pow(A, B)) is simplified to ln(A)*B @@ -927,14 +1047,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Power(param0, param1))); + GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Log(param0), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1)))); } // Test that ln(exp(A)) is simplified to A @@ -951,10 +1071,10 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Exp(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); @@ -981,13 +1101,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + 
GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1)))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1)))); } // Test that pow(A, 0) where A is a scalar is simplified to the scalar @@ -1005,14 +1126,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_EQ(root->literal().GetFirstElement(), 1); } @@ -1030,14 +1151,14 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast()); + EXPECT_THAT(root, GmockMatch(m::Broadcast())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); 
EXPECT_EQ(root->dimensions().size(), 0); @@ -1059,10 +1180,10 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(one)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); @@ -1082,13 +1203,14 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } // Test that pow(A, -1) is simplified to 1/A. 
@@ -1105,14 +1227,14 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0)))); EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), 1); @@ -1153,13 +1275,12 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); m->AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + HloPassFix simplifier(default_options_); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Convolution(lhs, rhs)); + GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs)))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { @@ -1196,13 +1317,12 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); m->AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + HloPassFix simplifier(default_options_); 
EXPECT_THAT(m->entry_computation()->root_instruction(), - op::ReduceWindow(param, op::Constant())); + GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { @@ -1225,12 +1345,11 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { padding)); m->AddEntryComputation(builder.Build()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Pad(param, op::Constant())); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + GmockMatch(m::Pad(m::Parameter(0), m::Constant()))); + HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1251,10 +1370,9 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { m->AddEntryComputation(std::move(computation)); EXPECT_THAT(m->entry_computation()->root_instruction(), - op::Reshape(op::Broadcast(op::Reshape(op)))); + GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op)))))); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + HloPassFix simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(m->entry_computation()->root_instruction(), op); @@ -1271,10 +1389,10 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert(m::Op().Is(input)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier 
simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); @@ -1292,10 +1410,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); @@ -1314,19 +1432,24 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier1(options); ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options2(bitcasting_callback()); + options2.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier2(options2); ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. 
- EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } // Test that unary concatenates are removed. @@ -1341,10 +1464,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); @@ -1371,16 +1494,17 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT( - computation->root_instruction(), - op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate( + m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0), + m::Op().Is(empty_slice), m::Parameter(1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(param0, param0, param1)); + GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0), + m::Parameter(1)))); } // Test that reduce of concat is simplified. 
@@ -1423,14 +1547,14 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), - op::Reduce(param2, zero))); + GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)), + m::Reduce(m::Parameter(1), m::Op().Is(zero))), + m::Reduce(m::Parameter(2), m::Op().Is(zero))))); } // Test a concatenate with only empty operands is removed. @@ -1453,10 +1577,10 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(empty_literal, empty_slice)); + GmockMatch(m::Concatenate(m::Op().Is(empty_literal), + m::Op().Is(empty_slice)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); @@ -1479,10 +1603,80 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { + auto m = CreateNewVerifiedModule(); + Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99}); + Shape 
concat_shape = ShapeUtil::MakeShape(F32, {50, 80}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + + HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0}, + /*limit_indices=*/{50, 10}, /*strides=*/{1, 1})); + + // Cannot merge 'slice0' and 'slice1' because of different start indices in + // dimension 0. + HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10}, + /*limit_indices=*/{100, 20}, /*strides=*/{1, 1})); + + // Cannot merge 'slice1' and 'slice2' because of stride in dimension 1. + HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20}, + /*limit_indices=*/{100, 40}, /*strides=*/{1, 2})); + + // Cannot merge 'slice2' and 'slice3' because of stride in dimension 1. + HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40}, + /*limit_indices=*/{100, 50}, /*strides=*/{1, 1})); + + // Can merge 'slice3' and 'slice4'. + HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50}, + /*limit_indices=*/{100, 60}, /*strides=*/{1, 1})); + + // Can merge 'slice4' and 'slice5'. + HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60}, + /*limit_indices=*/{100, 70}, /*strides=*/{1, 1})); + + // Cannot merge 'slice5' and 'slice6' because of overlap.
+ HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69}, + /*limit_indices=*/{100, 79}, /*strides=*/{1, 1})); + + // Cannot merge 'slice6' and 'slice7' because of slicing from a different + // parameter. + HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79}, + /*limit_indices=*/{100, 89}, /*strides=*/{1, 1})); + + builder.AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, + {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1)); + auto computation = m->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + auto s = m::Slice(m::Parameter(0)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1))))); + // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its + // shape should have dimensions {50, 30}. 
+ EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(), + ShapeUtil::MakeShape(F32, {50, 30}))); + EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40); } // Test that a simplification which changes layouts is not performed if layout @@ -1502,14 +1696,17 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } // Test that a simplification which preserves layouts is performed if layout @@ -1529,10 +1726,12 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Copy has been removed. 
@@ -1557,14 +1756,17 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } // Test transforming reshapes and transposes of rng. @@ -1588,13 +1790,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // Verify that reshape(transpose(rng)) is replaced by a single rng of the // same shape as the reshape.
- EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng())); EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), reshape_shape)); } @@ -1636,17 +1838,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(transformable_reshape, dimensions_wrong_reshape, - layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Op().Is(transformable_reshape), + m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); simplifier.Run(m.get()).ValueOrDie(); // Verify that only the first reshape is replaced. EXPECT_THAT( computation->root_instruction(), - op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1667,8 +1872,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1692,8 +1897,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); m->AddEntryComputation(builder.Build()); EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -1715,14 +1920,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { @@ -1742,14 +1950,17 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
- EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { @@ -1769,13 +1980,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Reshape(param0))); + GmockMatch(m::Reshape(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { @@ -1796,13 +2007,16 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { @@ -1821,13 +2035,14 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Transpose(m::Op().Is(transpose1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); EXPECT_EQ(std::vector({2, 1, 0}), computation->root_instruction()->dimensions()); } @@ -1846,13 +2061,13 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Broadcast(op::Reshape(param0))); + GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } // Test merging broadcast and reshape. 
@@ -1869,13 +2084,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param0))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { @@ -1891,14 +2106,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { @@ -1914,13 +2128,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + 
EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(computation->root_instruction()->dimensions(), ::testing::ElementsAre(3)); } @@ -1938,13 +2152,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); const std::vector broadcast_dims = computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); @@ -1964,14 +2178,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { @@ -1984,13 +2197,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } @@ -2004,14 +2217,13 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); @@ -2027,13 +2239,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { @@ -2046,13 +2259,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { 
HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_EQ(Cast(computation->root_instruction()) ->iota_dimension(), 3); @@ -2068,13 +2281,13 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); const int64 iota_dim = Cast(computation->root_instruction()) ->iota_dimension(); @@ -2091,13 +2304,14 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { @@ -2120,10 +2334,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); @@ -2153,8 +2367,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); auto has_negative_padding = [](const HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { @@ -2166,16 +2379,54 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { return false; }; - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_TRUE(has_negative_padding(pad)); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero))))); EXPECT_FALSE( has_negative_padding(computation->root_instruction()->operand(0))); } +TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) { + // Verify that interior padding on size-one dimensions of a pad instruction + // is removed by the simplifier.
+ HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 1}), "param")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + PaddingConfig padding; + for (int i = 0; i < 2; ++i) { + auto dimension = padding.add_dimensions(); + dimension->set_edge_padding_low(3); + dimension->set_edge_padding_high(3); + dimension->set_interior_padding(i * 3); + } + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(default_options_); + + ASSERT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); + ASSERT_TRUE(HasInteriorPadding(pad->padding_config())); + + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); + EXPECT_FALSE( + HasInteriorPadding(computation->root_instruction()->padding_config())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -2187,10 +2438,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); @@ 
-2210,10 +2461,10 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); @@ -2239,13 +2490,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Slice(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); @@ -2271,13 +2523,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - 
non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Slice(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { @@ -2296,10 +2549,10 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -2312,12 +2565,84 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } +TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::Iota(), + m::Scatter(m::Iota(), m::Concatenate(m::Iota(), m::Reshape()), + m::Reshape())))); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { + // Same as ReplacePermutationSortWithScatter except that the iota has F32 + // type. + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { + // Same as ReplacePermutationSortWithScatter except that the sort dimensions + // don't match. 
+ const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); @@ -2334,11 +2659,11 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { keys, {values0, values1})); auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(keys, values0, values1)); + GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0), + m::Op().Is(values1)))); } // Test that A && True is simplified to A @@ -2356,8 +2681,7 @@ TEST_F(AlgebraicSimplifierTest, AndTrue) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); 
EXPECT_EQ(root, param0); @@ -2378,8 +2702,7 @@ TEST_F(AlgebraicSimplifierTest, AndTrue2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2400,8 +2723,7 @@ TEST_F(AlgebraicSimplifierTest, AndFalse) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_false); @@ -2422,8 +2744,7 @@ TEST_F(AlgebraicSimplifierTest, AndFalse2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAnd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_false); @@ -2444,8 +2765,7 @@ TEST_F(AlgebraicSimplifierTest, OrTrue) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_true); @@ -2466,8 
+2786,7 @@ TEST_F(AlgebraicSimplifierTest, OrTrue2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, const_true); @@ -2488,8 +2807,7 @@ TEST_F(AlgebraicSimplifierTest, OrFalse) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2510,8 +2828,7 @@ TEST_F(AlgebraicSimplifierTest, OrFalse2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kOr); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); @@ -2641,15 +2958,15 @@ TEST_P(ConvInputPaddingTest, DoTest) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); 
SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrCat("size=3x3 ", testcase.expected_conv_window)); } @@ -2759,15 +3076,15 @@ TEST_P(ConvFilterPaddingTest, DoIt) { auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrFormat("size=%dx%d %s", conv->operand(1)->shape().dimensions(2), @@ -2908,8 +3225,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + simplifier_options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } @@ -3032,17 +3350,15 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { EXPECT_EQ(root, slice); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); 
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(scalar_param)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(scalar_param)) + .WithShapeEqualTo(&slice_shape))); } // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a @@ -3071,13 +3387,11 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { EXPECT_EQ(root, reshape); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(forty_two)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(forty_two)) + .WithShapeEqualTo(&reshape_shape))); } // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). @@ -3138,8 +3452,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
@@ -3147,7 +3460,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_THAT(root, + GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3224,8 +3538,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. @@ -3233,7 +3546,8 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)), + m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -3258,8 +3572,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); @@ -3295,8 +3608,7 @@ TEST_F(AlgebraicSimplifierTest, 
IteratorInvalidation) { m->AddEmbeddedComputation(std::move(dot_computation)); m->AddEntryComputation(call_builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } @@ -3313,11 +3625,10 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Constant(), op::Constant())); + GmockMatch(m::Tuple(m::Constant(), m::Constant()))); } // A dynamic-slice is trivial if its start indices are all zeroes and the size @@ -3337,10 +3648,9 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { /*slice_sizes=*/{10, 100, 1000})); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Parameter()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } // A dynamic-update-slice is trivial if its start indices are all zeroes and the @@ -3371,11 +3681,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Parameter(), op::Parameter())); + 
GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. @@ -3394,11 +3703,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_THAT(root->dimensions(), ElementsAre(2)); } @@ -3421,11 +3729,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } @@ -3442,11 +3749,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); 
EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3464,11 +3770,10 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } @@ -3486,11 +3791,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { @@ -3507,11 +3812,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { @@ -3528,8 +3833,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { 
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -3547,11 +3852,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter()); + EXPECT_THAT(root, GmockMatch(m::Parameter())); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { @@ -3569,11 +3874,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(1)); + EXPECT_THAT(root, GmockMatch(m::Parameter(1))); } TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { @@ -3591,11 +3896,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); 
auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2)))); EXPECT_EQ(root->slice_starts(0), 1); EXPECT_EQ(root->slice_limits(0), 2); } @@ -3613,11 +3918,11 @@ TEST_F(AlgebraicSimplifierTest, NegateNegate) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } TEST_F(AlgebraicSimplifierTest, NotNot) { @@ -3633,11 +3938,11 @@ TEST_F(AlgebraicSimplifierTest, NotNot) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Parameter(0)); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); } struct PadReduceWindowEffectiveBroadcastCase { @@ -3733,8 +4038,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { output_shape, pad, zero, window, add_computation)); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); @@ -3742,10 +4046,10 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { 
ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); if (param.should_become_broadcast) { - EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast())); } else { EXPECT_THAT(computation->root_instruction(), - op::ReduceWindow(::testing::_, zero)); + GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero)))); } } @@ -3815,8 +4119,7 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = @@ -3845,7 +4148,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // Test that we transform @@ -3893,19 +4196,19 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); - auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); - auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); - 
EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); + auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0)); + auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1)); + auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2))); } // Test that we transform @@ -3958,20 +4261,20 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); - auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); - auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); - auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), - match_dot_3)); + auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant())); + auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant())); + auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant())); + auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant())); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3))); } DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { @@ -4000,8 +4303,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { const 
HloComputation* const computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), operand); } @@ -4021,7 +4323,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) @@ -4078,8 +4380,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -4090,8 +4391,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } @@ -4149,8 +4450,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -4161,8 +4461,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), 
op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc new file mode 100644 index 0000000000000000000000000000000000000000..362bc44a1cf377b51c5519c6ab5e0d9628e80e58 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -0,0 +1,285 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace m = match; + +// If the argument instruction is a CRS in the sequence +// AR -> Convert -> Add -> CRS +// then return the AR in the sequence. 
+// TODO(b/117554291): Rewrite this to recognize more general patterns,
+// not just the specific one of AR -> Convert -> Add -> CRS.
+//
+// Returns the inner AllReduce (AR) when `instruction` is a CrossReplicaSum
+// without an all_reduce_id whose operand is Add(x, Convert(AR)), the AR has a
+// single user and produces BF16, and the Convert produces F32. Returns
+// absl::nullopt otherwise.
+absl::optional<HloInstruction*> MatchesArCrsPattern(
+    HloInstruction* instruction) {
+  HloInstruction *ar, *convert, *add, *crs;
+  if (Match(instruction,
+            m::CrossReplicaSum(
+                &crs, m::Add(&add, m::Op(),
+                             m::Convert(&convert,
+                                        m::CrossReplicaSum(&ar, m::Op()))))) &&
+      ar->users().size() == 1 && ar->shape().element_type() == BF16 &&
+      convert->shape().element_type() == F32 && !crs->all_reduce_id()) {
+    return ar;
+  }
+  return absl::nullopt;
+}
+
+}  // namespace
+
+// Returns the single kWhile instruction that calls the computation owning
+// `instruction` (which must be a parameter). Returns absl::nullopt when the
+// computation has zero or several callers, or its only caller is not a while.
+absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter(
+    HloInstruction* instruction) {
+  CHECK_EQ(HloOpcode::kParameter, instruction->opcode());
+  HloComputation* computation = instruction->parent();
+  auto caller_instructions = call_graph_->GetComputationCallers(computation);
+  if (caller_instructions.size() == 1) {
+    auto caller_instruction = caller_instructions[0];
+    if (caller_instruction->opcode() == HloOpcode::kWhile) {
+      return caller_instruction;
+    }
+  }
+  return absl::nullopt;
+}
+
+// Collects every kTuple instruction that can flow into `instruction`, looking
+// through kDomain, kGetTupleElement and while-loop parameters (init value and
+// body root). An empty result means "could not prove that only tuples flow
+// here".
+//
+// NOTE(review): there is no visited set, so a while body whose root tuple
+// forwards get-tuple-element(parameter) at the queried index looks like it
+// would recurse without bound. Confirm, and thread a visited set through if
+// so (this needs a signature change in the header).
+std::vector<HloInstruction*> ArCrsCombiner::GetAllTuples(
+    HloInstruction* instruction) {
+  if (instruction->opcode() == HloOpcode::kTuple) {
+    return {instruction};
+  }
+  if (instruction->opcode() == HloOpcode::kDomain) {
+    return GetAllTuples(instruction->operands()[0]);
+  }
+  if (instruction->opcode() == HloOpcode::kParameter) {
+    auto maybe_while = WhileFromBodyParameter(instruction);
+    if (!maybe_while) {
+      return {};
+    }
+    auto while_instr = *maybe_while;
+    auto init_tuples = GetAllTuples(while_instr->while_init());
+    auto body_tuples =
+        GetAllTuples(while_instr->while_body()->root_instruction());
+    // Both sources must resolve to tuples, otherwise nothing is proven.
+    if (init_tuples.empty() || body_tuples.empty()) {
+      return {};
+    }
+    init_tuples.insert(init_tuples.end(), body_tuples.begin(),
+                       body_tuples.end());
+    return init_tuples;
+  }
+  if (instruction->opcode() == HloOpcode::kGetTupleElement) {
+    std::vector<HloInstruction*> result_tuples;
+    for (auto tuple : GetAllTuples(instruction->operands()[0])) {
+      auto tmp_tuples =
+          GetAllTuples(tuple->mutable_operand(instruction->tuple_index()));
+      if (tmp_tuples.empty()) {
+        return {};
+      }
+      result_tuples.insert(result_tuples.end(), tmp_tuples.begin(),
+                           tmp_tuples.end());
+    }
+    return result_tuples;
+  }
+  return {};
+}
+
+// Returns true if, for every tuple that can flow into
+// `tuple_shaped_instruction`, elements `i1` and `i2` can be shown to compute
+// the same value. Returns false if the set of tuples cannot be determined.
+bool ArCrsCombiner::TupleElementsComputeSameValue(
+    HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
+    absl::flat_hash_map<int64, int64>* visited_pairs) {
+  auto tuples = GetAllTuples(tuple_shaped_instruction);
+  if (tuples.empty()) {
+    return false;
+  }
+  for (auto tuple : tuples) {
+    CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
+    if (!InstructionsComputeSameValue(tuple->mutable_operand(i1),
+                                      tuple->mutable_operand(i2),
+                                      visited_pairs)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/* static */
+// Test-only entry point: builds a throwaway combiner and call graph for the
+// module that owns i1 and i2 (they must share a module) and runs
+// InstructionsComputeSameValue on them.
+bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1,
+                                                     HloInstruction* i2) {
+  ArCrsCombiner combiner(/*num_spatial_partitions=*/2);
+  auto module = i1->parent()->parent();
+  CHECK_EQ(module, i2->parent()->parent());
+  combiner.call_graph_ = CallGraph::Build(module);
+  absl::flat_hash_map<int64, int64> visited_pairs;
+  return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs);
+}
+
+// Returns true if i1 and i2 can be shown to evaluate to the same value.
+// A pair that is already in `visited_pairs` is assumed equal, which is what
+// terminates the recursion through while loops.
+//
+// NOTE(review): visited_pairs keeps a single partner per min_uid (emplace
+// does not overwrite an existing key), so a second pair sharing the same
+// min_uid is not remembered. Confirm whether that can defeat the cycle
+// check; a set of pairs would avoid it but changes the signature.
+bool ArCrsCombiner::InstructionsComputeSameValue(
+    HloInstruction* i1, HloInstruction* i2,
+    absl::flat_hash_map<int64, int64>* visited_pairs) {
+  if (i1 == i2) {
+    return true;
+  }
+  auto uid1 = i1->unique_id();
+  auto uid2 = i2->unique_id();
+  auto min_uid = std::min(uid1, uid2);
+  auto max_uid = std::max(uid1, uid2);
+  auto it = visited_pairs->find(min_uid);
+  if (it != visited_pairs->end() && max_uid == it->second) {
+    return true;
+  }
+  auto opcode1 = i1->opcode();
+  // Bind by reference: the operand list is only read here.
+  const auto& operands1 = i1->operands();
+  if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) {
+    return false;
+  }
+  visited_pairs->emplace(min_uid, max_uid);
+  for (int i = 0; i < operands1.size(); ++i) {
+    auto operand1 = operands1[i];
+    auto operand2 = i2->operands()[i];
+    if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) {
+      return false;
+    }
+  }
+  if (opcode1 == HloOpcode::kParameter) {
+    // In the general case, we don't try to prove equality of parameters.
+    // We only try in the context of get-tuple-element
+    // (see TupleElementsComputeSameValue).
+    return false;
+  }
+  if (opcode1 == HloOpcode::kGetTupleElement) {
+    return i1->tuple_index() == i2->tuple_index() ||
+           TupleElementsComputeSameValue(operands1[0], i1->tuple_index(),
+                                         i2->tuple_index(), visited_pairs);
+  }
+  // Don't check that the operands are identical, because Identical can
+  // return false for instructions that compute the same value but are not
+  // identical, which we don't want. We have checked the arguments with
+  // InstructionsComputeSameValue earlier.
+  auto eq_instructions = [](const HloInstruction* i1,
+                            const HloInstruction* i2) -> bool { return true; };
+  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
+    return *a == *b;
+  };
+  return i1->Identical(*i2, eq_instructions, eq_computations,
+                       /*layout_sensitive=*/false);
+}
+
+// Fills all_reduce_map_ with the ARs of every AR -> Convert -> Add -> CRS
+// chain found in the non-fusion computations, keyed by the AR's
+// all_reduce_id.
+void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    for (HloInstruction* instruction : computation->instructions()) {
+      auto ar = MatchesArCrsPattern(instruction);
+      // The matcher does not require the AR to carry an all_reduce_id, so
+      // check it before dereferencing the optional id.
+      if (ar && (*ar)->all_reduce_id()) {
+        all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar);
+      }
+    }
+  }
+}
+
+// Drops every group in which the Add of some partition cannot be shown to
+// compute the same value as the Add of partition 0.
+void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
+  // Ids are collected and erased after the loop: erasing from
+  // all_reduce_map_ while the range-for is walking it would invalidate the
+  // iterator being advanced.
+  std::vector<int64> ids_to_erase;
+  for (const auto& it : all_reduce_map_) {
+    const auto& instruction_vec = it.second;
+    CHECK_EQ(instruction_vec.size(), num_spatial_partitions_);
+
+    auto instr_0 = instruction_vec[0];
+    auto add_0 = instr_0->users()[0]->users()[0];
+    CHECK_EQ(HloOpcode::kAdd, add_0->opcode());
+
+    for (int i = 1; i < instruction_vec.size(); ++i) {
+      auto instr_i = instruction_vec[i];
+      auto add_i = instr_i->users()[0]->users()[0];
+      CHECK_EQ(HloOpcode::kAdd, add_i->opcode());
+      absl::flat_hash_map<int64, int64> visited_pairs;
+      if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) {
+        ids_to_erase.push_back(it.first);
+        // One mismatch is enough to reject the whole group.
+        break;
+      }
+    }
+  }
+  for (int64 id : ids_to_erase) {
+    all_reduce_map_.erase(id);
+  }
+}
+
+// Rewrites every remaining group: the early AllReduce is removed, the other
+// summand of the Add is divided by the number of spatial partitions, and the
+// CRS inherits the AllReduce's all_reduce_id. Returns whether any instruction
+// was actually rewritten.
+StatusOr<bool> ArCrsCombiner::RewriteGraph() {
+  if (all_reduce_map_.empty()) {
+    return false;
+  }
+
+  auto computation_is_addition = [](HloComputation* c) {
+    return c->instruction_count() == 3 &&
+           Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
+  };
+
+  // Candidates whose reduction is not a plain addition are skipped below, so
+  // a non-empty map does not by itself mean the module changed.
+  bool changed = false;
+  for (const auto& it : all_reduce_map_) {
+    for (HloInstruction* all_reduce : it.second) {
+      auto parent_computation = all_reduce->parent();
+      auto convert = all_reduce->users()[0];
+      auto add = convert->users()[0];
+      auto crs = add->users()[0];
+
+      if (!computation_is_addition(all_reduce->called_computations()[0]) ||
+          !computation_is_addition(crs->called_computations()[0])) {
+        continue;
+      }
+      HloInstruction* other_summand = (add->operands()[0] == convert)
+                                          ? add->operands()[1]
+                                          : add->operands()[0];
+      // To move the AR past the addition, we need to divide other_summand by
+      // the number of spatial partitions.
+      CHECK_EQ(all_reduce->user_count(), 1);
+      TF_CHECK_OK(
+          all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0)));
+      auto shape = other_summand->shape();
+      Literal lit(shape);
+      lit.PopulateWithValue<float>(num_spatial_partitions_);
+      auto divisor = parent_computation->AddInstruction(
+          HloInstruction::CreateConstant(lit.Clone()));
+      auto division =
+          parent_computation->AddInstruction(HloInstruction::CreateBinary(
+              shape, HloOpcode::kDivide, other_summand, divisor));
+      TF_CHECK_OK(other_summand->ReplaceUseWith(add, division));
+      // The AllReduce and the CRS are combined to an all-core AllReduce.
+      crs->set_all_reduce_id(all_reduce->all_reduce_id());
+      TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+// Pass entry point: group candidate AllReduces, keep the groups that are
+// provably safe to move, then rewrite them.
+StatusOr<bool> ArCrsCombiner::Run(HloModule* module) {
+  // Drop groups left over from a previous Run(): RewriteGraph removes the
+  // AllReduce instructions it rewrites, so stale entries would dangle.
+  all_reduce_map_.clear();
+  call_graph_ = CallGraph::Build(module);
+
+  GroupAllReducesById(module);
+
+  KeepProvablyEqualInstructionGroups();
+
+  return RewriteGraph();
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h
new file mode 100644
index 0000000000000000000000000000000000000000..f6a7ef76ec3b76972d1b2c7fb548cecfb9423160
--- /dev/null
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Combine an AllReduce and a CrossReplicaSum when they are close to each other
+// in the graph, to use an efficient CrossReplicaSum implementation that
+// fully utilizes the interconnect bandwidth.
+class ArCrsCombiner : public HloModulePass {
+ public:
+  explicit ArCrsCombiner(int num_spatial_partitions)
+      : num_spatial_partitions_(num_spatial_partitions) {}
+  absl::string_view name() const override { return "ar-crs-combiner"; }
+  StatusOr<bool> Run(HloModule* module) override;
+
+  // Helper method to allow testing of InstructionsComputeSameValue.
+  static bool TestInstructionsComputeSameValue(HloInstruction* i1,
+                                               HloInstruction* i2);
+
+ private:
+  // If the passed instruction is a while parameter, and the while body is only
+  // called by a single while instruction, return the while instruction.
+  absl::optional<HloInstruction*> WhileFromBodyParameter(
+      HloInstruction* instruction);
+
+  // Returns a vector of tuple instructions.
+  // If all instructions that flow to "instruction" are tuples, return them.
+  // Otherwise, return an empty vector.
+  std::vector<HloInstruction*> GetAllTuples(HloInstruction* instruction);
+
+  // Checks whether two different elements in the same tuple compute the same
+  // value.
+  bool TupleElementsComputeSameValue(
+      HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
+      absl::flat_hash_map<int64, int64>* visited_pairs);
+
+  // Returns whether the instructions i1 and i2 can be shown to evaluate to the
+  // same value. Handling WHILE requires recursion, which may cause us to visit
+  // the same instruction again. To avoid infinite loops, we pass a cache of
+  // visited instruction pairs.
+  bool InstructionsComputeSameValue(
+      HloInstruction* i1, HloInstruction* i2,
+      absl::flat_hash_map<int64, int64>* visited_pairs);
+
+  // Populates all_reduce_map_.
+  void GroupAllReducesById(HloModule* module);
+
+  // Looks at each AllReduce group in all_reduce_map_, and keeps only the
+  // groups for which it's safe to move the AllReduce later in the HLO graph.
+  void KeepProvablyEqualInstructionGroups();
+
+  // Performs the graph rewrite that eliminates the early AllReduce and turns
+  // the later CRS into an AllReduce.
+  StatusOr<bool> RewriteGraph();
+
+  int num_spatial_partitions_;
+
+  // Map from all-reduce ids to the all reduce instructions.
+  absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_;
+
+  std::unique_ptr<CallGraph> call_graph_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..10171835d83c75fef091a34b8fe102d263211307
--- /dev/null
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc
@@ -0,0 +1,496 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+class ArCrsCombinerTest : public HloTestBase {};
+
+TEST_F(ArCrsCombinerTest, SameValueTestBasecase) {
+  const char* module_str = R"(
+HloModule foobar
+
+ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
+  %p = f32[2,2] parameter(0)
+  %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
+  %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
+  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto root_tuple = module->entry_computation()->root_instruction();
+  auto i1 = root_tuple->operands()[0];  // %constant.f32.1
+  auto i2 = root_tuple->operands()[1];  // %constant.f32.2
+  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(
+      i1, module->entry_computation()->parameter_instruction(0)));
+  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
+}
+
+TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) {
+  const char* module_str = R"(
+HloModule foobar
+
+ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) {
+  %x = f32[] parameter(0)
+  ROOT %tuple = (f32[], f32[]) tuple(%x, %x)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto root_tuple = module->entry_computation()->root_instruction();
+  auto i1 = root_tuple->operands()[0];  // %x
+  auto i2 = root_tuple->operands()[1];  // %x again: same instruction
+  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
+}
+
+TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) {
+  const char* module_str = R"(
+HloModule foobar
+
+ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %tuple = (f32[], f32[]) tuple(%x, %y)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto root_tuple = module->entry_computation()->root_instruction();
+  auto i1 = root_tuple->operands()[0];  // %x
+  auto i2 = root_tuple->operands()[1];  // %y: a different parameter
+  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
+}
+
+TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) {
+  const char* module_str = R"(
+HloModule foobar
+
+ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) {
+  %p = f32[2,2] parameter(0)
+  %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
+  %tuple1 = (f32[2,2]) tuple(%constant.f32)
+  %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
+  ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto root_tuple = module->entry_computation()->root_instruction();
+  auto i1 = root_tuple->operands()[0];  // %tuple1: one operand
+  auto i2 = root_tuple->operands()[1];  // %tuple2: two operands
+  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
+}
+
+TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) {
+  const char* module_str = R"(
+HloModule foobar
+
+ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
+  %p = f32[2] parameter(0)
+  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
+  %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]}
+  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  auto root_tuple = module->entry_computation()->root_instruction();
+  auto i1 = root_tuple->operands()[0];  // %slice.1
+  auto i2 = root_tuple->operands()[1];  // %slice.2: same slice bounds
+  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
+}
+
+TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + 
%get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile1) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] 
get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile2) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %init.tuple = (f32[2,2], f32[2,2]) 
tuple(%constant.f32.1, %constant.f32.2) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile3) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]->operands()[0]; // 
%get-tuple-element.1 + auto i2 = body_tuple->operands()[1]->operands()[0]; // %get-tuple-element.2 + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + + %cross-replica-sum.ar.1 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())), + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); + for (int i = 0; i < replica_groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. + auto group_before = replica_groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = replica_groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + + %cross-replica-sum.ar.1 = 
bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32.1, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32.2, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index f70f6ddfec69c0113a1afe2073a2392098f49456..0e6ca1871b379a2f55b92207133822fc6258b007 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -107,19 +107,37 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } std::unique_ptr Mean( - int64 element_count, HloInstruction* operand, + HloInstruction* element_count, HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* elem_count_recip = - 
add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(1.0 / element_count))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, - operand, elem_count_recip); + auto broadcast = add_instruction( + HloInstruction::CreateBroadcast(operand->shape(), element_count, {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide, + operand, broadcast); + } + + std::unique_ptr DynamicElementCountPerFeature( + HloInstruction* operand, int64 feature_index, + const std::function)>& + add_instruction) { + auto elements_per_feature_u32 = add_instruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (i == feature_index) { + continue; + } + auto dynamic_dimension_size = + add_instruction(HloInstruction::CreateGetDimensionSize( + ShapeUtil::MakeShape(U32, {}), operand, i)); + elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply, + dynamic_dimension_size, elements_per_feature_u32)); + } + + return HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + elements_per_feature_u32); } // Replaces the existing HLO instruction old_instruction, with @@ -195,9 +213,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape operand_shape = operand->shape(); PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); - const int64 feature_count = operand_shape.dimensions(feature_index); - const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); 
HloInstruction* offset = batch_norm->mutable_operand(2); @@ -220,6 +235,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( } } + auto elements_per_feature = + add(DynamicElementCountPerFeature(operand, feature_index, add)); + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); @@ -243,13 +261,13 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_reduce_computation)); // E[X]. - auto mean = add(Mean(elements_per_feature_int64, sum, add)); + auto mean = add(Mean(elements_per_feature, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); + auto square_mean = add(Mean(elements_per_feature, squared_sum, add)); // E^2[X]. auto mean_square = @@ -458,9 +476,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( int64 feature_index = batch_norm->feature_index(); - const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); - const int64 feature_count = activation_shape.dimensions(feature_index); - const int64 elements_per_feature_int64 = size_in_elements / feature_count; + auto elements_per_feature = + add(DynamicElementCountPerFeature(activation, feature_index, add)); auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); @@ -553,15 +570,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add( - Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add)); - auto elements_per_feature_literal = - LiteralUtil::CreateR0(elements_per_feature_int64); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - 
elements_per_feature_literal.Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, add(HloInstruction::CreateBroadcast( activation_shape, elements_per_feature, {}))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 08cf8026177d77ff98cca5e5d168ac3194936b35..8e8fbbd935b154e5a77d68e60d861601d740bf03 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -36,7 +36,21 @@ limitations under the License. namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +class BatchNormExpanderTest : public HloTestBase { + protected: + // BatchNorm should have a dynamic sized dividor for mean operations. + int64 CountGetDimensionSize(const HloModule& module) { + int64 count = 0; + for (HloComputation* comp : module.computations()) { + for (HloInstruction* inst : comp->instructions()) { + if (inst->opcode() == HloOpcode::kGetDimensionSize) { + count++; + } + } + } + return count; + } +}; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -68,6 +82,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } @@ -110,6 +125,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 40c012a5e4214f00dbeaca4e8cbfaa668089c6e8..8d7c62447852fd946440c41389300a92377c471f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -746,8 +746,7 @@ StatusOr> BufferAssigner::Run( LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing, bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) { - BufferAssigner assigner(allow_input_output_aliasing, - allocate_buffers_for_constants, std::move(colorer), + BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), std::move(reuse_checker)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), @@ -1434,33 +1433,40 @@ BufferAssigner::MergeColocatedBufferSets( computation == module->entry_computation(); }; + std::vector set_can_be_merged(colocated_buffer_sets.size(), true); + + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. 
+ for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { + for (auto& buffer : colocated_buffer_sets[i]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { + set_can_be_merged[i] = false; + break; + } + } + } + // Returns true if the two colocated buffer sets (specified by their indices // into the colocated_buffer_sets) can be merged into a single set. auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, - &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs, entry parameters or - // constants. - // - // Buffer liveness does not report the correct live range for entry - // parameter and live out buffers so we have to special case them here. On - // backends that support constant buffer allocations, constant buffers are - // assigned globals in readonly storage so we can't merge colocated buffer - // sets containing constants with colocated buffer sets containing writing - // instructions or other constants. - // - // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to - // the caller of the executable so we can't write to entry parameters - // either, and the argument for not merging constants also applies to entry - // parameters. 
- for (int64 key : {i, j}) { - for (auto& buffer : colocated_buffer_sets[key]) { - if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || - buffer->instruction()->opcode() == HloOpcode::kConstant) { - return true; - } - } + &set_can_be_merged](int64 i, int64 j) { + if (!set_can_be_merged[i] || !set_can_be_merged[j]) { + return true; } // Colocated sets satisfy the invariant that all buffers within a set have diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index d8e1612b899f10a5793f9c65c59a41024dfdddd1..0a9fdede803e84ca42472259084615c031b206eb 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -545,12 +545,10 @@ class BufferAssigner { ReuseAllocationFunction reuse_checker = nullptr); private: - BufferAssigner(bool allow_input_output_aliasing, - bool allocate_buffers_for_constants, + BufferAssigner(bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) - : allow_input_output_aliasing_(allow_input_output_aliasing), - allocate_buffers_for_constants_(allocate_buffers_for_constants), + : allocate_buffers_for_constants_(allocate_buffers_for_constants), colorer_(colorer), reuse_checker_(reuse_checker) {} virtual ~BufferAssigner() = default; @@ -640,10 +638,6 @@ class BufferAssigner { LogicalBuffer::Color::Hasher> SplitBuffersByColor(const absl::flat_hash_set& buffers); - // If true, buffer assignments assumes that input parameter buffers and output - // buffers can be shared if their sizes match. - bool allow_input_output_aliasing_; - // If true, allocate buffers for constant instructions. 
bool allocate_buffers_for_constants_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index b1fc50cb1881241a0a53b024b06342308cabdd62..8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloTestBase { } std::unique_ptr RunBufferAssignmentWithInstructionSequence( - HloModule* module, - absl::Span instruction_sequence, + HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { HloSchedule schedule(module); schedule.set_sequence(module->entry_computation(), instruction_sequence); @@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, absl::make_unique(schedule), ByteSizeOf, @@ -2162,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // nodes are traversed during BufferAssignment. TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -2391,15 +2390,16 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. 
- schedule.set_sequence(module->entry_computation(), - {input1, weights1, one, output1, while1->operand(0), - while1, input0, weights0, zero, output0, - while0->operand(0), while0, gte0, gte1, root_add}); + schedule.set_sequence( + module->entry_computation(), + {input1, weights1, one, output1, while1->mutable_operand(0), while1, + input0, weights0, zero, output0, while0->mutable_operand(0), while0, + gte0, gte1, root_add}); // If this ASSERT fails, we constructed a bogus sequence above and this test // itself is buggy. diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index aeee543e8435200915ab992e2aa146a3c17646d5..40825a78716b1c0b9fb0121787977d275891c0f8 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -117,7 +117,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto log = builder.AddInstruction( HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -164,7 +164,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -213,7 +213,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -247,7 +247,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto add = builder.AddInstruction( 
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -289,7 +289,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -336,7 +336,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(add)); HloSchedule schedule(module.get()); @@ -373,7 +373,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto outer_tuple = builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -393,7 +393,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { TEST_F(BufferLivenessTest, EmbeddedComputation) { // Test MaybeLiveOut and MayInterfere for embedded computation. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); auto embedded_param = embedded_builder.AddInstruction( @@ -450,7 +450,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0.shape(), tuple_constant, 0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -576,7 +576,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -611,8 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - std::unique_ptr BuildModule(const bool update_uses_tuple_element1, - const bool fuse_gte0) { + std::unique_ptr BuildModule( + const bool update_uses_tuple_element1, const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -646,7 +646,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. 
@@ -802,7 +802,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index bdd5069632e84fe6c67ca129f726432479ac1b35..7987343bfaf1069fd550909d127e4b11f2124701 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -325,6 +325,15 @@ bool CallGraph::IsFlattened() const { return true; } +std::vector CallGraph::GetComputationCallers( + HloComputation* c) { + std::vector callers; + for (auto callsite : GetNode(c).caller_callsites()) { + callers.push_back(callsite.instruction()); + } + return callers; +} + std::pair CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, HloInstruction* b) const { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index cb56f4789d06ac33acdaadc8b619b9e37f683d58..05c7c998738f861ee804d1ec87bfa5fb17ddfb74 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -236,6 +236,10 @@ class CallGraph { // FlattenCallGraph. bool IsFlattened() const; + // Returns a vector of instructions calling the passed computation. + // (Often a vector of size 1.) 
+ std::vector GetComputationCallers(HloComputation* c); + string ToString() const; private: diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 67132274c0dcbfda831c79836d052bb51b753ec7..1965925fa7f6d50b1d7af918bc3468d4b4d5d0a2 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -86,15 +86,15 @@ CompileOnlyService::CompileAheadOfTime( Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - const auto& program_shape = instance.computation.host_program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = - *instance.result_layout; + instance.result_layout->ToProto(); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(program_shape, instance.argument_layouts, - &execution_options)); + CreateModuleConfig( + ProgramShape(instance.computation.host_program_shape()), + instance.argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index c899ffb9dc562426ef14c0d414469c04debeec70..844b42a38d7539cccd5c4e30071c0ea6693e3bba 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -105,8 +105,6 @@ class ComputationPlacer { // Map from platform kind to computation placer singleton. 
static std::map* GetPlatformComputationPlacers(); - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 7f7f1503a099b3a67ed22cb5978c01da6cf8ba88..95c7724c3c93507ae61a984301ecfc0111bef192 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -142,16 +142,16 @@ std::vector GetMaskIds(int64 group_size, int64 group_count) { // Finally we use the Eq op of these two broadcasted constants and get the // desired mask. HloInstruction* GetExpandedFilterMask( - const Shape& filter_shape, int64 input_feature_dim, - int64 output_feature_dim, int64 group_count, + const Shape& filter_shape, int64 kernel_input_feature_dim, + int64 kernel_output_feature_dim, int64 group_count, const std::function)>& add_instruction) { Shape expanded_filter_shape = - ExpandedFilterShape(filter_shape, group_count, input_feature_dim); + ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim); Shape mask_shape = ShapeUtil::MakeShape( S32, AsInt64Slice(expanded_filter_shape.dimensions())); - int64 output_feature = filter_shape.dimensions(output_feature_dim); - int64 group_size = filter_shape.dimensions(input_feature_dim); + int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim); + int64 group_size = filter_shape.dimensions(kernel_input_feature_dim); // Create a 'input_feature' sized linspace and 'output_feature' sized linspace // that will be broadcasted into perpendicular dimensions and compared. 
@@ -159,15 +159,14 @@ HloInstruction* GetExpandedFilterMask( GetMaskIds(group_size, group_count); const std::vector output_feature_filter_mask = GetMaskIds(output_feature / group_count, group_count); - auto mask1 = add_instruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1(input_feature_filter_mask))); - auto broadcasted_mask1 = add_instruction( - HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim})); + auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast( + mask_shape, mask1, {kernel_input_feature_dim})); auto mask2 = add_instruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1(output_feature_filter_mask))); - auto broadcasted_mask2 = add_instruction( - HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim})); + auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast( + mask_shape, mask2, {kernel_output_feature_dim})); // Compare the broadcasted output feature linspace to the input feature // linspace to create a diagonal predicate. 
@@ -189,91 +188,203 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { }; auto dim_numbers = convolution->convolution_dimension_numbers(); - int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension(); - int64 group_size = filter->shape().dimensions(input_feature_dim); - int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension(); - auto expanded_filter_shape = - ExpandedFilterShape(filter->shape(), group_count, input_feature_dim); - HloInstruction* filter_mask = GetExpandedFilterMask( - filter->shape(), input_feature_dim, output_feature_dim, group_count, add); + int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); + int64 group_size = filter->shape().dimensions(kernel_input_feature_dim); + int64 kernel_output_feature_dim = + dim_numbers.kernel_output_feature_dimension(); + auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count, + kernel_input_feature_dim); + HloInstruction* filter_mask = + GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim, + kernel_output_feature_dim, group_count, add); HloInstruction* expanded_filter; if (group_size == 1) { bool depthwise_separable = - (group_count == filter->shape().dimensions(output_feature_dim)); + (group_count == filter->shape().dimensions(kernel_output_feature_dim)); // If the code generator handles depthwise separable convolutions // inherently, then no filter expansion is needed. if (!filter_expansion_ && depthwise_separable) { - const int64 old_kernel_input_feature_dimension = - dim_numbers.kernel_input_feature_dimension(); - const int64 old_kernel_output_feature_dimension = - dim_numbers.kernel_output_feature_dimension(); - - // For depthwise convolutions, we want the kernel input feature dimension - // to be smaller than the output feature dimension. If that's not the - // case, we swap the dimensions. 
- if (old_kernel_input_feature_dimension > - old_kernel_output_feature_dimension) { - Shape reshaped_filter_shape = filter->shape(); - auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); - std::swap(dimensions[old_kernel_input_feature_dimension], - dimensions[old_kernel_output_feature_dimension]); - - auto reshaped_filter = - add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); - - dim_numbers.set_kernel_input_feature_dimension( - old_kernel_output_feature_dimension); - - dim_numbers.set_kernel_output_feature_dimension( - old_kernel_input_feature_dimension); - - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), - reshaped_filter, group_count, convolution->window(), dim_numbers, - convolution->precision_config()); - - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); - } return Status::OK(); } // We want to repeat 'filter' in the 'input_feature_dim' dimension // 'group_count' times. 
Shape reshaped_filter_shape = - ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); + ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape()); auto reshaped_filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); std::vector broadcast_dims; for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) { - if (i == input_feature_dim) { + if (i == kernel_input_feature_dim) { continue; } broadcast_dims.push_back(i); } expanded_filter = add(HloInstruction::CreateBroadcast( expanded_filter_shape, reshaped_filter, broadcast_dims)); + + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + auto new_filter = add(HloInstruction::CreateTernary( + expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter, + zero_filter)); + + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), new_filter, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); } else { - // We could possibly also use reshape, broadcast, reshape instead of concat - // here, but it would require more complex code, and for depthwise - // convolution we would never end up in this branch. - std::vector concat_operands(group_count, filter); - expanded_filter = add(HloInstruction::CreateConcatenate( - expanded_filter_shape, concat_operands, input_feature_dim)); + int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); + + int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dim); + + // If group_count == output_feature, then we map those grouped convolutions + // onto depthwise convolution. 
This is done by adding an additional spatial + // dimension to the activations, kernel, and the output. + // E.g., we would turn + // [2, 12]{B, IF} conv [3, 4]{IF, OF} into + // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the + // additional spatial dimension. The generated convolution output will be + // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. + + if (group_count == output_feature && !filter_expansion_) { + auto filter = convolution->mutable_operand(1); + auto activation = convolution->mutable_operand(0); + + // Add spatial dimension to the activation, and reshape. + Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); + + int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; + + reshaped_activation_shape.set_dimensions(activation_input_feature_dim, + group_count); + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); + + // Add spatial dimension to the filter, and reshape. + Shape reshaped_filter_shape = filter->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); + + filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + // Edit convolution dimension numbers. Note that kernel_input_feature_dim + // now becomes a spatial dimension, and the newly added dimension of size + // 1 is the new kernel_input_feature_dim. + dim_numbers.add_input_spatial_dimensions(new_spatial_dim); + dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim); + dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); + dim_numbers.add_output_spatial_dimensions(new_spatial_dim); + + // Add window for the new spatial dimension. 
+ Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(group_size); + + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, group_count, new_window, + dim_numbers, convolution->precision_config())); + + // Delete the extra spatial dimension, and reshape. + Shape reshaped_convolution_shape = + ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); + auto reshaped_convolution = HloInstruction::CreateReshape( + reshaped_convolution_shape, new_convolution); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reshaped_convolution))); + + } else { + // The filter expansion mechanism adds zeroes in the kernel. + // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask + // would look like (IF on the Y-axis, OF on the X-axis) + // 1 1 1 1 0 0 0 0 0 0 0 0 + // 1 1 1 1 0 0 0 0 0 0 0 0 + // 0 0 0 0 1 1 1 1 0 0 0 0 + // 0 0 0 0 1 1 1 1 0 0 0 0 + // 0 0 0 0 0 0 0 0 1 1 1 1 + // 0 0 0 0 0 0 0 0 1 1 1 1 + // + // Instead of convolving the above with the input, we instead slice the + // kernel into three kernels, each containing islands of 1s from the + // filter above. We also slice the activations in the IF dimension with + // each slice of size = group_size. For each slice, we perform + // convolutions, and concatenate the generated outputs in the output OF + // dimension. 
+ + std::vector sliced_convolutions; + auto activation = convolution->mutable_operand(0); + std::vector slice_strides(filter->shape().dimensions_size(), 1); + std::vector filter_slice_starts(filter->shape().dimensions_size(), + 0); + std::vector filter_slice_limits( + filter->shape().dimensions().begin(), + filter->shape().dimensions().end()); + std::vector activation_slice_starts( + activation->shape().dimensions_size(), 0); + std::vector activation_slice_limits( + activation->shape().dimensions().begin(), + activation->shape().dimensions().end()); + + int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dim); + auto output_feature_dim = dim_numbers.output_feature_dimension(); + int64 filter_slice_width = output_feature / group_count; + + int64 activation_input_feature_dim = + dim_numbers.input_feature_dimension(); + + for (int64 i = 0; i < group_count; i++) { + filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width; + filter_slice_limits[kernel_output_feature_dim] = + (i + 1) * filter_slice_width; + auto filter_sliced_shape = filter->shape(); + filter_sliced_shape.set_dimensions(kernel_output_feature_dim, + filter_slice_width); + auto filter_slice = add(HloInstruction::CreateSlice( + filter_sliced_shape, filter, filter_slice_starts, + filter_slice_limits, slice_strides)); + + activation_slice_starts[activation_input_feature_dim] = i * group_size; + activation_slice_limits[activation_input_feature_dim] = + (i + 1) * group_size; + auto activation_sliced_shape = activation->shape(); + activation_sliced_shape.set_dimensions(activation_input_feature_dim, + group_size); + auto activation_slice = add(HloInstruction::CreateSlice( + activation_sliced_shape, activation, activation_slice_starts, + activation_slice_limits, slice_strides)); + + auto conv_slice_shape = convolution->shape(); + conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width); + + auto new_convolution = add(HloInstruction::CreateConvolve( + 
conv_slice_shape, activation_slice, filter_slice, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config())); + + sliced_convolutions.push_back(new_convolution); + } + + auto new_conv = HloInstruction::CreateConcatenate( + convolution->shape(), sliced_convolutions, output_feature_dim); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_conv))); + } } - auto zero = add(HloInstruction::CreateConstant( - LiteralUtil::Zero(expanded_filter_shape.element_type()))); - auto zero_filter = - add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); - auto new_filter = add( - HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect, - filter_mask, expanded_filter, zero_filter)); - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), new_filter, - /*feature_group_count=*/1, convolution->window(), dim_numbers, - convolution->precision_config()); - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc index 28373ebf636c7b6b3059dcf6cd931901ebc87fc2..e6bf2143a21bd5001d3530fe8727c88504be1d43 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -82,18 +82,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 ConvolutionFeatureGroupConverter converter; ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); - // Make sure the convolution is converted to one with feature_group_count = 1. 
- EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - EXPECT_EQ(root->feature_group_count(), 1); - // Verify that the filter operand has been replaced. - EXPECT_THAT(root->operand(1), - op::Select(op::Eq(op::Broadcast(op::Constant()), - op::Broadcast(op::Constant())), - // We expect to see Concatenate here instead of - // Broadcast, because feature_group_count < input - // feature dimension. - op::Concatenate(op::Parameter(), op::Parameter()), - op::Broadcast(op::Constant()))); + // Make sure the convolution is replaced with a concatenate. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + // And the operands of the concatenate are convolutions, each with a feature + // group count = 1. + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(0)->feature_group_count(), 1); + EXPECT_EQ(root->operand(1)->feature_group_count(), 1); } } // namespace diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 4e547d925f62dce1d2dd23a39a28ca8c23ba9f2f..df6059663876dfde71f4c75d3931b3d2de72c1df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -442,7 +442,6 @@ class CopyRemover { const HloOrdering& ordering, HloModule* module) : module_(module), alias_analysis_(alias_analysis), - ordering_(ordering), buffer_value_tracker_(*module, alias_analysis, ordering) {} // Try to elide the given copy. The copy is elided if the instruction is not @@ -1003,7 +1002,6 @@ class CopyRemover { HloModule* module_; const HloAliasAnalysis& alias_analysis_; - const HloOrdering& ordering_; // Object tracking the HLO values contained in each HLO buffer. 
BufferValueTracker buffer_value_tracker_; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 7446bc7cc11553984dcf1cea00c58072d2cbf0f0..e4e9d7ba05c115be9dd0eb53ebd7de208d514efb 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -94,7 +94,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -114,7 +114,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -127,7 +127,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = @@ -181,7 +181,7 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -217,7 +217,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloInstruction* old_root = module->entry_computation()->root_instruction(); @@ -238,7 +238,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -261,7 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); @@ -283,7 +283,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); 
module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -310,7 +310,7 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, @@ -351,7 +351,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -388,7 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -1295,7 +1295,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { TEST_F(CopyInsertionTest, SwizzlingWhile) { // Test a while instruction with a body which permutes its tuple parameter // elements. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1362,7 +1362,7 @@ TEST_F(CopyInsertionTest, CrossingParameters) { // | / \ | // | / \| // (p1 , p0) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1395,7 +1395,7 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1428,7 +1428,7 @@ TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1461,7 +1461,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1496,7 +1496,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { // | | | // | | | // +-- (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1534,7 +1534,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { // | Add----+ // | | | // +-- (p0 , p1) - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1569,7 +1569,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // the operation (instruction) on the 
element makes the live range of the // respective input and output elements different than if the instruction were // not there (as in the SwizzlingWhile test above). - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1632,7 +1632,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { // the while body is a single constant (both loop state elements are the same // constant). This means no copies are necessary because both loop state // elements are the same so interchanging them is a no-op. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1693,7 +1693,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { const Shape loop_state_shape = ShapeUtil::MakeTupleShape( {element_shape, element_shape, element_shape, element_shape}); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, element_shape, "param_0")); @@ -1783,7 +1783,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). The body constant should be copied. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2763d18121a0c1328ea0c11d825476923ae2b15d..ce4c2a9cc69240b9565b35a3f2504d7fc9373917 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -96,6 +96,7 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 73b03440cbb936017257b8a92f16dcc25d41e21c..796a7cf94d02b0ad42366387a9d3f8d589b8840a 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -61,19 +61,6 @@ Disabling these as a starting point. // TODO(b/64227304) Creating a custom pass pipeline will replace this. 
namespace { -class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager { - public: - FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes) - : llvm::legacy::FunctionPassManager(m), - disable_expensive_passes_(disable_expensive_passes) {} - void add(llvm::Pass* p) override { - llvm::legacy::FunctionPassManager::add(p); - } - - private: - bool disable_expensive_passes_; -}; - class FilteredPassManager : public llvm::legacy::PassManager { public: explicit FilteredPassManager(bool disable_expensive_passes) @@ -96,8 +83,7 @@ class FilteredPassManager : public llvm::legacy::PassManager { std::unique_ptr CompilerFunctor::operator()( llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); - FilteredFunctionPassManager function_passes(&module, - disable_expensive_passes_); + llvm::legacy::FunctionPassManager function_passes(&module); VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 4ce5a8a29255a763c83941efb6de9b7c652cedb4..6374822c81bf42fd12829f57cf93c19457128219 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -76,6 +76,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -268,10 +269,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - pass.AddPass( - /*is_layout_sensitive=*/false, - [](const Shape&, const Shape&) { return false; }, - /*enable_dot_strength_reduction=*/false); + pipeline.AddPass(); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_enable_dot_strength_reduction(false); + pass.AddPass(options); pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO @@ -334,10 +336,11 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pass.AddInvariantChecker( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); - pass.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_dot_strength_reduction(false); + pass.AddPass>(options); pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } @@ -587,9 +590,9 @@ StatusOr> CpuCompiler::RunBackend( // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). 
- TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module.get(), BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( @@ -779,7 +782,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction())); + ScheduleModule(module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 29abf38e439d919ff93629ed992cb3ff93a929bd..818b2b0d0db2893e11fa46c7867e6c74bbbb6905 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -51,8 +51,7 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 3c3c047bfe8ee0d1ad90ede2432a86264f47870b..3b91b15ba9b5603b50f78f489e9a3fdad354c083 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -49,7 +49,7 @@ class CpuExecutable : public Executable { public: CpuExecutable(std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, const string& 
entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f9cd61bea3dc86cadff99d4a90eca44c16520823..6f79ad7c1468f27c74d84770ec6358fbcd1c1f09 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) { (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); } +bool HasExactlyOneUse(const HloInstruction& hlo_instr) { + return hlo_instr.user_count() == 1 && + absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; +} + bool CanBeOutputFused(const HloInstruction* producer, const HloInstruction* consumer) { return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && - producer->user_count() == 1; + HasExactlyOneUse(*producer) == 1; } bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index c95a514ca04bee1fb4c03ee21510eb8da3122081..527df0bd1c23bba74f32226e5622fed32f7dcf84 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -321,7 +321,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -370,7 +370,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, broadcast1)); - auto 
module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -410,7 +410,7 @@ TEST_F(OpcodeFusionTest, Exponential_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -429,7 +429,7 @@ TEST_F(OpcodeFusionTest, Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -447,7 +447,7 @@ TEST_F(OpcodeFusionTest, Reverse_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -489,7 +489,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, transpose2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -498,7 +498,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { } TEST_F(OpcodeFusionTest, UnaryMapOfExp) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -517,7 +517,7 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { } TEST_F(OpcodeFusionTest, BinaryMapOfExps) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); 
HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -542,7 +542,7 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { } TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -573,7 +573,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); @@ -712,7 +712,7 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, } TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -725,7 +725,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/false); @@ -738,7 +738,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -751,7 +751,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { - auto module = 
CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/true); @@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +TEST_F(InstructionFusionTest, + DotOperationFusion_DontOutputFuseDuplicateOperands) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[60,1]{1,0} parameter(1) + c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT d = f32[50,1]{1,0} add(c, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + struct GatherLoopFusionTestSpec { string test_name; string hlo_computation_text; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 2cd52e4a18a4524365393db5f658a982d83a7632..6c61b64758ede160e2d50e4429590a789ec253c3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -73,7 +73,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -114,7 +114,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { 
builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -158,7 +158,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -192,7 +192,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -232,7 +232,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -353,7 +353,7 @@ static void AssertCorrectLayoutForDotOutputFusion( } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), 
TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -365,7 +365,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -377,7 +377,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -389,7 +389,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -401,7 +401,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, @@ -413,7 +413,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewUnverifiedModule(); + 
std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index b8ace5702688096822573c7afae234cbcbe77b28..92debb83e33b1400a59e5eef0f90971392ab7b22 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -22,7 +22,6 @@ limitations under the License. namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; -const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaEnableExperimentalLlvmIrGemm = "xla_enable_experimental_llvm_ir_gemm"; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 620c45fa391e69ef88269d44709404e6f71b30cb..4032c2da2f33ee61da8771ae6225a14172cbe6e8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -111,7 +111,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); @@ -140,7 +140,7 @@ StatusOr IrEmitter::EmitComputation( // readcyclecounter if it is unavailable. 
bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; - profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); + profiling_state_ = ProfilingState(use_rdtscp); if (instruction_order == nullptr) { TF_RETURN_IF_ERROR(computation->Accept(this)); } else { @@ -1379,33 +1379,6 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -// Fills up the free variables in 'index_with_free_var' with values from -// 'filler_index'. The size of free variables must be the same as the -// size of 'filler_index'. -// -// This is often used after dimension reduction, where -// 'index_with_free_var' has one or more dimensions reduced, which serves as -// free variables (represented as nullptr). For example, if we have a 4 -// dimensional input and index for the dimension being reduced is -// 2 (third dimension), we will have an index like [i, j, NULL, k] -// after reduced dimension. -// -// Here we fill up that free variable by 'filler_index', which contains -// the value in the reduced dimension. -static llvm_ir::IrArray::Index FillReducedDimensionIndex( - llvm_ir::IrArray::Index index_with_free_var, - llvm_ir::IrArray::Index filler_index) { - llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); - - for (size_t i = 0; i < index_with_free_var.size(); ++i) { - if (index_with_free_var[i] == nullptr) { - index_with_free_var[i] = *it++; - } - } - CHECK(filler_index.end() == it); - return index_with_free_var; -} - Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); return EmitTargetAddressForOp(parameter); @@ -2194,14 +2167,6 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { return Status::OK(); } -// If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. 
-static const HloInstruction* StripTranspose(const HloInstruction& hlo) { - if (hlo.IsRank2Transpose()) { - return hlo.operand(0); - } - return &hlo; -} - Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { @@ -2600,10 +2565,17 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } -Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { - TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); +Status IrEmitter::HandleAfterAll(HloInstruction* after_all) { + TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0); // No code to generate, but we need to emit an address for book-keeping. - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all)); + return Status::OK(); +} + +Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { + // AddDedendency just forwards its zero-th operand. 
+ emitted_value_[add_dependency] = + GetEmittedValueFor(add_dependency->operand(0)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 136b88ff75ea8a5f48b42d3476219f18f5ecb39a..559a8162a2d53f28ea6817653503c216af90a610 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -159,7 +159,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; - Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* after_all) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; @@ -467,9 +468,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // profiling a computation. class ProfilingState { public: - ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {} - ProfilingState(bool use_rdtscp, llvm::Value* prof_counters) - : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} + ProfilingState() : use_rdtscp_(false) {} + explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {} // Record the cycle counter before an HLO executes. void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); @@ -494,9 +494,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // intrinsic? 
bool use_rdtscp_; - // The argument which corresponds to the profile counter buffer. - llvm::Value* prof_counters_; - // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 669eeb95f3299623a7556bfbb8045fd77f5d0745..722aa3120ef4d8c957873ac58c361f19632dde1f 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -41,61 +42,60 @@ void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { std::sort(row_to_sort, row_to_sort + num_elements); } -// For floating point numbers, we want a total order comparator. -NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. Also we want to have a stable sort, so if the keys are the -// same, we compare the index values. -template -bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { - bool lhs_is_negative = std::signbit(lhs); - bool rhs_is_negative = std::signbit(rhs); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; - } - bool lhs_nan = std::isnan(lhs); - bool rhs_nan = std::isnan(rhs); - // Exactly one number is nan? - if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; +// We would like a total order of floating point numbers so that the +// sort has a predictable behavior in the presence of NaNs. Rather +// than using floating point comparison, we use the following trick: +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? 
0x7FFFFFFF - x : x; +// then y is ordered as an int32 such that finite values have the +// obvious order, -0 is ordered before 0, and -NaN and NaN appear at +// the beginning and end of the ordering. +template +CastType Convert(KeyType value) { + CastType casted_value; + memcpy(&casted_value, &value, sizeof(CastType)); + if (casted_value < 0) { + return static_cast(std::numeric_limits::max()) - + casted_value; } - if (lhs != rhs) { - return lhs < rhs; - } - return lhs_index < rhs_index; + return casted_value; +} + +template +bool LessThan(KeyType lhs, KeyType rhs) { + return Convert(lhs) < + Convert(rhs); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), lhs.second, - Eigen::half_impl::half_to_float(rhs.first), rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), + 
Eigen::half_impl::half_to_float(rhs.first)); + }); } template diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index f77641eb7da71117092730c1fd5090c61c939813..efccadedf27181a4cddf4f1dc3610f7c6db1d821 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -128,8 +128,18 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + void* func_addr = nullptr; + if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { + // On Mac OS X, 'name' may have a leading underscore prefix, even though the + // registered name may not. + std::string stripped_name(name.begin() + 1, name.end()); + func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name); + } else { + func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + } + if (func_addr == nullptr) { + VLOG(2) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 691b3c7bee26e84edbef18a4ac10a9cafd29c61a..f8f5f392da8ab3348e63185aecf7b639daacaa42 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -50,7 +50,7 @@ class CpuEigenDotOperationTest /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(entry_computation)); 
CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index d201a151d7a9edb86a0de15819ea99f95a9c4d28..e30f95311fce229f9c559d3bb40142151e8bf3e3 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -46,7 +46,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); - std::unique_ptr module = CreateNewUnverifiedModule(); + std::unique_ptr module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CompileAndVerifyIr(std::move(module), filecheck_pattern, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 773336c7a92f808f0c6370c7353e780b1471470f..9b10c49f4f547edfb2164f98c49cceb031148bdc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -91,7 +91,7 @@ TEST_P(CpuUnaryIntrinsicTest, DoIt) { /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); string check_lines{spec.check_lines.data(), spec.check_lines.size()}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 3b87683ffffefd2aa24dd234cc072425bef00a24..fa0e09ff6b5694c0e97963b83c6e541b858a1376 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ 
b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -63,7 +63,7 @@ CHECK-NOT: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -104,14 +104,14 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [4 x i8] -CHECK: private constant [8 x i8] +CHECK-DAG: private constant [4 x i8] +CHECK-DAG: private constant [8 x i8] CHECK-NOT: private constant [4 x i8] CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index f5419b7063bea6d1f5d24fde0a22e829413b8d93..a7702c2aeeaff8a46a2c4f2785ccb873ea2c08e5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -56,7 +56,7 @@ TEST_F(CpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index 990ff94ba2338cb663b655ca3106bda83ab718a3..70008947f371d25e95d02839c30ba822fce7a292 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index d6371283221b63b30f968929fe2807eae3f22df0..e84bf00153aa28df29d8df486b92654feab4afbf 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -251,6 +251,7 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0; virtual Status HandleAfterAll(HloInstructionPtr token) = 0; // Invoked to inform the visitor that the traversal has completed, and that diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index e57184f639f4f2c618b980a5082381f4b9c28b19..80ea5be298aea44a0f424398da74c4e478f10346 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -206,6 +206,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGetDimensionSize(HloInstructionPtr get_size) override { return DefaultAction(get_size); } + Status HandleAddDependency(HloInstructionPtr add_dependency) override { + return DefaultAction(add_dependency); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". 
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d0472689bf48092ceef2e9792c1358687d707ec --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -0,0 +1,459 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +namespace { +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} +} // namespace + +class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { + public: + explicit DynamicDimensionInferenceVisitor( + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) + : param_bindings_(param_bindings), 
parent_(parent) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static Status Run(HloComputation* computation, + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) { + DynamicDimensionInferenceVisitor visitor(param_bindings, parent); + return computation->Accept(&visitor); + } + + Status HandleParameter(HloInstruction* hlo) override; + + Status HandleReduce(HloInstruction* hlo) override; + + Status HandleDot(HloInstruction* hlo) override; + + Status HandleTranspose(HloInstruction* hlo) override; + + Status HandleReshape(HloInstruction* hlo) override; + + Status HandlePad(HloInstruction* hlo) override; + + Status HandleBroadcast(HloInstruction* hlo) override; + + Status HandleGetDimensionSize(HloInstruction* hlo) override; + + Status HandleSelect(HloInstruction* hlo) override; + + Status HandleConvolution(HloInstruction* hlo) override; + + Status HandleReduceWindow(HloInstruction* hlo) override; + + Status HandleSelectAndScatter(HloInstruction* hlo) override; + + Status HandleGetTupleElement(HloInstruction* hlo) override; + + Status HandleElementwiseUnary(HloInstruction* hlo) override; + + Status HandleElementwiseBinary(HloInstruction* hlo) override; + + private: + using OperandDynamicDimensionFn = std::function; + + Status ForEachOperandDynamicDimension(HloInstruction* inst, + const OperandDynamicDimensionFn&); + + // Pass through a dynamic dimension from the input to the output with the same + // value and index in the shape. This is a helper function to handle trivial + // instructions like elementwise operations. + Status PassThroughDynamicDimension(HloInstruction*); + + // The dynamic parameter bindings of this computation. + const DynamicParameterBinding& param_bindings_; + + // A pointer to DynamicDimensionInference, used to update the dynamic mapping. 
+ DynamicDimensionInference* parent_; +}; + +Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + return UnimplementedStrCat( + "Asked to propagate a dynamic dimension from hlo ", + operand->ToString(), "@", index.ToString(), "@", dimension, + " to hlo ", hlo->ToString(), ", which is not implemented."); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (hlo->tuple_index() == index[0]) { + ShapeIndex new_index = + ShapeIndexView(index).ConsumeFront().ToShapeIndex(); + parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size); + } + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + int64 broadcast_dim = hlo->dimensions(dimension); + parent_->SetDynamicSize(hlo, index, broadcast_dim, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (operand_index != 0) { + return Unimplemented( + "Dynamic dimension on padding value is not supported"); + } + const PaddingConfig_PaddingConfigDimension& padding_config = + hlo->padding_config().dimensions(dimension); + if (padding_config.interior_padding() == 0 && + padding_config.edge_padding_low() == 0 && + padding_config.edge_padding_high() == 0) { + 
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + return Status::OK(); + } else { + return Unimplemented( + "Dynamic dimension propagation on padding dimension is not " + "supported."); + } + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce = hlo; + int64 operand_count = reduce->operand_count(); + CHECK_EQ(operand_count % 2, 0); + if (operand_index >= operand_count / 2) { + // Init values don't have a dynamic size. + return Status::OK(); + } + if ((absl::c_count(reduce->dimensions(), dimension) != 0)) { + // Dimension is to be reduced, stop tracing. + return Status::OK(); + } + + // Find out the new dynamic dimension after reduce. + int64 dimensions_not_reduced_count = 0; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (dimension == i) { + parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, + dynamic_size); + + return Status::OK(); + } + if (absl::c_count(reduce->dimensions(), i) == 0) { + dimensions_not_reduced_count++; + } + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* dot = hlo; + const DotDimensionNumbers& dimension_numbers = + dot->dot_dimension_numbers(); + // A map from the operand dimensions to result dimension.
+ absl::flat_hash_map result_dim_mapping; + int64 current_result_dims = 0; + std::unordered_set batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); + + for (int64 i : dimension_numbers.rhs_batch_dimensions()) { + result_dim_mapping[i] = current_result_dims++; + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.lhs_contracting_dimensions(), i)) { + if (operand_index == 0) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.rhs_contracting_dimensions(), i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), + i)) { + if (operand_index == 1) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + // Check if the operand dim is in the result shape. If so, add another + // work item to trace that dimension. 
+ auto iter = result_dim_mapping.find(dimension); + if (iter != result_dim_mapping.end()) { + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension], + dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleConvolution( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* conv = hlo; + const ConvolutionDimensionNumbers& dimension_numbers = + conv->convolution_dimension_numbers(); + + if (operand_index == 0) { + if (dimension == dimension_numbers.input_batch_dimension()) { + parent_->SetDynamicSize(conv, {}, + dimension_numbers.output_batch_dimension(), + dynamic_size); + return Status::OK(); + } + + if (dimension == dimension_numbers.input_feature_dimension()) { + return Status::OK(); + } + } else { + if (dimension == dimension_numbers.kernel_input_feature_dimension()) { + return Status::OK(); + } + } + + return Unimplemented("Dynamic Spatial Convolution is not supported: %s", + conv->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize( + HloInstruction*) { + // Dynamic dimension doesn't propagate through GetDimensionSize: + // + // Input: F32[x, y, z] + // | + // GetDimensionSize(1): U32[] + // + // The returned value is a scalar, which doesn't have any dynamic dimension in + // the shape (although the value contains the real size of the dynamic + // dimension of the input). 
+ return Status::OK(); +} + +Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reshape = hlo; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(), + reshape->shape()); + for (auto& unmodified : unmodified_dims) { + if (unmodified.first == dimension) { + parent_->SetDynamicSize(reshape, {}, unmodified.second, + dynamic_size); + return Status::OK(); + } + } + return Unimplemented( + "Dynamic Reshape on modified dimensions is yet not supported: %s", + reshape->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduceWindow( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce_window = hlo; + const WindowDimension& window_dimension = + reduce_window->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial reduce 
window is not supported: %s", + reduce_window->ToString()); + } + + parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* select_and_scatter = hlo; + const WindowDimension& window_dimension = + select_and_scatter->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial select and scatter is not supported: %s", + select_and_scatter->ToString()); + } + + parent_->SetDynamicSize(select_and_scatter, {}, dimension, + dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { + return param_bindings_.ForEachBinding( + [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, + const DynamicParameterBinding::DynamicDimension& dynamic_dimension) { + if (dynamic_dimension.parameter_num != hlo->parameter_number()) { + return Status::OK(); + } + HloComputation* computation = hlo->parent(); + HloInstruction* target_parameter = + computation->parameter_instruction(dynamic_dimension.parameter_num); + + HloInstruction* dynamic_size = + computation->parameter_instruction(dynamic_parameter.parameter_num); + for (int64 i : dynamic_parameter.parameter_index) { + dynamic_size = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(dynamic_size->shape(), {i}), + dynamic_size, i)); + } + + parent_->SetDynamicSize(target_parameter, + dynamic_dimension.parameter_index, + dynamic_dimension.dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( + HloInstruction* inst, const 
OperandDynamicDimensionFn& fn) { + for (int64 operand_index = 0; operand_index < inst->operand_count(); + ++operand_index) { + auto iter = + parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index)); + if (iter != parent_->per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = parent_->GetDynamicSize( + dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim, operand_index, + dynamic_size)); + } + } + } + return Status::OK(); +} + +/* static */ +StatusOr DynamicDimensionInference::Run( + HloModule* module) { + VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + DynamicDimensionInference inference(module); + TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); + return inference; +} + +DynamicDimensionInference::DynamicDimensionInference(HloModule* module) + : module_(module) {} + +Status DynamicDimensionInference::AnalyzeDynamicDimensions() { + return DynamicDimensionInferenceVisitor::Run( + module_->entry_computation(), module_->dynamic_parameter_binding(), this); +} + +HloInstruction* DynamicDimensionInference::GetDynamicSize( + HloInstruction* inst, const ShapeIndex& index, int64 dim) const { + auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); + if (iter != dynamic_mapping_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..164d15bf111a92e3da957f609b54ee0662ef18b1 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// DynamicDimensionInference analyzes each HLO instruction in a graph and +// infers which dimensions are dynamic and which scalar instructions +// represent the runtime real size of those dynamic dimensions. +class DynamicDimensionInference { + public: + static StatusOr Run(HloModule* module); + + string ToString() const; + + // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, + // returns a scalar HloInstruction that represents the runtime size of that + // dimension. Otherwise returns nullptr.
+ HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, + int64 dim) const; + + friend class DynamicDimensionInferenceVisitor; + + private: + explicit DynamicDimensionInference(HloModule* module); + + // DynamicDimension is used as a key in the dynamic key-value mapping. It + // unambiguously represents a dynamic dimension of an instruction at a given + // index. + struct DynamicDimension { + // HloInstruction that holds the dimension. + HloInstruction* inst; + // Subshape of the instruction that holds the dimension. + ShapeIndex index; + // The dimension number of the dynamic dimension at given index of a given + // instruction. + int64 dim; + + // Artifacts needed to make this struct able to be used as a `key` in absl + // maps. "friend" keywords are added so these functions can be found through + // ADL. + template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.inst, m.index, m.dim); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.inst == rhs.inst && lhs.index == rhs.index && + lhs.dim == rhs.dim; + } + }; + + // Update the dynamic mapping so that we know dimension `dim` of instruction + // `inst` at `index` has a dynamic size, and its runtime size is represented + // by a scalar instruction `size`. + void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, + HloInstruction* size) { + dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); + auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); + iter.first->second.emplace(DynamicDimension{inst, index, dim}); + } + + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in + // module_. + Status AnalyzeDynamicDimensions(); + + // HloModule being analyzed. + HloModule* module_; + + // dynamic_mapping_ holds the result of the analysis.
It maps a dynamic + // dimension to a scalar HloInstruction that represents the real dynamic size + // of the dynamic dimension. + using DynamicMapping = absl::flat_hash_map; + DynamicMapping dynamic_mapping_; + + using PerHloDynamicDimensions = + absl::flat_hash_map>; + PerHloDynamicDimensions per_hlo_dynamic_dimensions_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea9ebed45d99797ce4f80376ec3d0b758da3ca17 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -0,0 +1,535 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicDimensionInferenceTest : public HloTestBase { + protected: + DynamicDimensionInferenceTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); + } + + Status RunInference() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module_.get())); + + inference_ = absl::make_unique(inference); + return Status::OK(); + } + + HloComputation* GetAdd() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return 
module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + std::unique_ptr inference_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); +}; + +TEST_F(DynamicDimensionInferenceTest, ParamTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "param")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { + // When data flows through GTE, the dynamic dimension size keeps the + // same, and the index has its front popped. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + auto gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, param, 0)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_THAT(inference_->GetDynamicSize(gte, {}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { + // When data flows through elementwise, the dynamic dimension size keeps the + // same. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto* negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { + // Same as ReduceTestI, but only reduce one dimension. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction( + HloInstruction::CreateReduce(reduce_shape, negate, init, {1}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, DotTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + 
dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + 
dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, TransposeTest) { + // Test the ability to trace unmodified dimensions + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, scalar_shape_, "size_param")); + + auto* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + 
DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 3})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr); + 
EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + Status status = RunInference(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { + // Test the ability to trace broadcast dimension. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 4}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(output_shape, a_param, {1})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { + // Test the ability to trace reduce window batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. 
+ for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { + // Test the ability to trace select and scatter batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. 
+ for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8bfc8905064bcd7b68fe259fbcc1546ff083dbd --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +Status DynamicParameterBinding::Bind( + const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) { + auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter); + TF_RET_CHECK(result.second); + return Status::OK(); +} + +absl::optional +DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { + auto param_iter = bindings_.find(dynamic_dimension); + if (param_iter == bindings_.end()) { + return absl::nullopt; + } + return param_iter->second; +} + +DynamicParameterBindingProto DynamicParameterBinding::ToProto() const { + DynamicParameterBindingProto result; + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + DynamicParameterBindingProto::Binding binding_proto; + binding_proto.set_dynamic_param_num(dynamic_param.parameter_num); + for (int64 i : dynamic_param.parameter_index) { + binding_proto.add_dynamic_param_index(i); + } + + binding_proto.set_target_param_num(dynamic_dimension.parameter_num); + + for (int64 i : dynamic_dimension.parameter_index) { + binding_proto.add_target_param_index(i); + } + + binding_proto.set_target_param_dim_num(dynamic_dimension.dimension); + result.add_entries()->Swap(&binding_proto); + } + return result; +} + +StatusOr DynamicParameterBinding::CreateFromProto( + const DynamicParameterBindingProto& proto) { + DynamicParameterBinding result; + for (const DynamicParameterBindingProto::Binding& binding : 
proto.entries()) { + int64 dynamic_param_num = binding.dynamic_param_num(); + ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(), + binding.dynamic_param_index().end()); + int64 target_param_num = binding.target_param_num(); + ShapeIndex target_param_index(binding.target_param_index().begin(), + binding.target_param_index().end()); + int64 target_dim_num = binding.target_param_num(); + + TF_RETURN_IF_ERROR( + result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, + DynamicDimension{target_param_num, target_param_index, + target_dim_num})); + } + + return result; +} + +string DynamicParameterBinding::ToString() const { + std::vector pieces; + pieces.push_back("DynamicParameterBinding: "); + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + pieces.push_back(absl::StrFormat( + " -- Input param number %lld at %s has dim %lld as dynamic" + " dimension, which is represented by param number %lld at " + "%s", + dynamic_dimension.parameter_num, + dynamic_dimension.parameter_index.ToString(), + dynamic_dimension.dimension, dynamic_param.parameter_num, + dynamic_param.parameter_index.ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + +Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const { + for (const auto& binding : bindings_) { + TF_RETURN_IF_ERROR(fn(binding.second, binding.first)); + } + return Status::OK(); +} + +Status DynamicParameterBinding::Verify(const HloModule& module) const { + const HloComputation* entry = module.entry_computation(); + return ForEachBinding([&](const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) + -> Status { + TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + 
entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), + dynamic_parameter.parameter_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(), + dynamic_dimension.parameter_index)); + TF_RET_CHECK( + dynamic_dimension.dimension < + ShapeUtil::Rank(ShapeUtil::GetSubshape( + entry->parameter_instruction(dynamic_dimension.parameter_num) + ->shape(), + dynamic_dimension.parameter_index))); + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding) { + out << binding.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..dd474d8eed1b2c30ddb8f624a864198c74eacaba --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; +// We currently use an explicit API that takes an extra parameter to indicate +// the runtime size of a dynamic dimension. DynamicParameterBinding indicates +// the relationship between parameter: We can have a dynamic parameter that +// points to another target parameter to indicate that the target parameter is +// dynamic. +// +// +// TODO(b/119520625): Remove this API once we have more dynamic shape infra +// ready. +class DynamicParameterBinding { + public: + // DynamicParameter represents a special parameter that is used to represent + // the runtime size of a dimension of another parameter. A dynamic parameter + // has to be a scalar value. + struct DynamicParameter { + // The parameter number of dynamic parameter. + int64 parameter_num; + // The index of the parameter. + ShapeIndex parameter_index; + }; + + // DynamicDimension represents a dimension whose size is determined at + // runtime. A DynamicDimension's runtime size is determined by the binded + // DynamicParameter using `DynamicParameterBinding::Bind` method. + struct DynamicDimension { + // The parameter number of dynamic dimension. + int64 parameter_num; + // The subshape index of the parameter. + ShapeIndex parameter_index; + // The dimension number in the subshape. + int64 dimension; + + // "friend" keyword are added so these functions can be found by ADL. 
+ template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.parameter_num, m.parameter_index, + m.dimension); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.parameter_num == rhs.parameter_num && + lhs.parameter_index == rhs.parameter_index && + lhs.dimension == rhs.dimension; + } + }; + + DynamicParameterBinding() = default; + + virtual ~DynamicParameterBinding() = default; + + // Adds binding which indicates that the dimension indicated by + // `dynamic_dimension` is dynamic, and its runtime size is represented by + // `dynamic_parameter`. + Status Bind(const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension); + + // Returns the parameter and the index representing the runtime size of + // dimension `dim_num` of parameter `param_num` at `param_index`. + // + // Returns nullopt if the binding is not set. + absl::optional GetBinding( + const DynamicDimension& dynamic_dimension); + + using BindingFn = + std::function; + + // Iterate through each binding. + Status ForEachBinding(BindingFn fn) const; + + DynamicParameterBindingProto ToProto() const; + + static StatusOr CreateFromProto( + const DynamicParameterBindingProto& proto); + + string ToString() const; + + // Verifies that the given binding is valid for the given module. + // Specifically, the binding's parameter and parameter size should be valid. + Status Verify(const HloModule& module) const; + + private: + // Keeps track of mappings from DynamicDimension to DynamicParameter. The + // direction of is chosen so that we can easily query if a dimension is + // dynamic and which dynamic parameter represents the real size of that + // dimension. 
+ absl::flat_hash_map bindings_; +}; + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..83a6d83dffde7995bd8e43917d13c5fd2705ba6f --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class DynamicParameterBindingTest : public HloTestBase {}; + +TEST_F(DynamicParameterBindingTest, SimpleBinding) { + // 'b' is a dynamic shape; 'a' represents the real size of b's first + // dimension. + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[10] parameter(1) + ROOT root = (f32[], f32[10]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + /*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBinding) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's first + // dimension. 
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's both + // dimensions. 
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10, 10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10, 10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10, 10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 1})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + + TF_EXPECT_OK(binding.Verify(*module)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index f98c943669be8c14d245896b91cee3eee1e47429..6f1f95f2e9082649b6ca9cc0da5c238e15b77c10 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. 
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -1671,26 +1672,66 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( b_->SetInsertPoint(init_block); + // Assign a unique id for each *different* operand, and count how often each + // operand is used. If all operands are different, the usage count will be 1 + // for each operand. + absl::flat_hash_map to_unique_operand_id; + std::vector operand_usage_count; + for (const auto* operand : hlo->operands()) { + if (to_unique_operand_id.contains(operand)) { + ++operand_usage_count[to_unique_operand_id[operand]]; + } else { + int64 unique_operand_id = to_unique_operand_id.size(); + to_unique_operand_id[operand] = unique_operand_id; + operand_usage_count.push_back(1); + } + } + + // To avoid that we emit the same operand more than once, we create one basic + // block for each *different* operand with a PHI node for the different source + // index inputs. 
+ std::vector emit_operand_blocks( + to_unique_operand_id.size(), nullptr); + std::vector source_index_phis(to_unique_operand_id.size(), + nullptr); + for (const auto* operand : hlo->operands()) { + int64 operand_id = to_unique_operand_id[operand]; + if (emit_operand_blocks[operand_id] != nullptr) { + continue; + } + + emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand_id", operand_id), b_); + auto saved_insert_point = b_->GetInsertPoint(); + llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_); + source_index_phis[operand_id] = + PHI(source_index.GetType(), operand_usage_count[operand_id]); + auto operand_index = source_index; + operand_index[concat_dim] = source_index_phis[operand_id]; + + // Create the terminator of the block before calling operand generators, + // because they require non-degenerate basic blocks. + b_->SetInsertPoint(llvm::BranchInst::Create( + exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id])); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(operand_index)); + output->addIncoming(value, b_->GetInsertBlock()); + b_->SetInsertPoint(init_block, saved_insert_point); + } + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), b_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, - false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. 
- b_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, b_->GetInsertBlock()); + int64 operand_id = to_unique_operand_id[operand]; + source_index_phis[operand_id]->addIncoming(source_index[concat_dim], + b_->GetInsertBlock()); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), + emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. @@ -2204,13 +2245,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( : iota->shape(); PrimitiveType component_element_type = component_shape.element_type(); llvm::Value* iota_result; - if (ShapeUtil::ElementIsIntegral(component_shape)) { + if (primitive_util::IsIntegralType(component_element_type) || + component_element_type == PRED) { iota_result = b_->CreateIntCast( elem_index_linear, llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), /*isSigned=*/false); } else { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + TF_RET_CHECK( + primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; if (component_element_type == BF16) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 45f620f3f33eee41eefa9ddfdfb166a5ba76caef..b34bca55a48b113c325dbf28c03f7a0f5b71f658 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -61,7 +61,7 @@ struct ExecutionOutput { class Executable { public: explicit Executable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : hlo_module_(std::move(hlo_module)), @@ -162,7 +162,7 @@ class Executable { return hlo_profile_printer_data_ != nullptr; } - const 
HloModule& module() const { return *hlo_module_; } + HloModule& module() const { return *hlo_module_; } const bool has_module() const { return hlo_module_ != nullptr; } @@ -199,7 +199,7 @@ class Executable { // HloModule this was compiled from. BufferAssignment keeps pointers to // HloInstructions owned by the HloModule so we need to keep the HloModule // around. - const std::unique_ptr hlo_module_; + const std::unique_ptr hlo_module_; // HloSnapshot this was compiled from. Null if not dumping executions. std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b1629616acd2bb715d5aa1a89286a38a45417d2c..bfd1b6cb1492f5cb709e2ecefe73782094e26f5e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -701,6 +701,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_element_type_converter", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc index d7829045cc127deaa4c2c9b705dca5285d704af2..3a09d4d4716950a09d65dd093272482d55ac5c27 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -43,13 +43,14 @@ bool IsForwardConvolutionCanonical(const HloInstruction& conv) { // dilation), returns kPad and/or kSlice instructions that explicitly apply the // padding; otherwise returns the original input operand. When there is both // positive padding (including dilation) and negative padding, we insert both -// kPad and kSlice. 
+// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved +// into a kPad or kSlice op. HloInstruction* MaybePaddedAndSlicedInput( - const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, + Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* input) { HloComputation* computation = input->parent(); - if (!window_util::HasSymmetricPadding(conv_window) || - window_util::HasBaseDilation(conv_window)) { + if (!window_util::HasSymmetricPadding(*conv_window) || + window_util::HasBaseDilation(*conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. // @@ -62,12 +63,21 @@ HloInstruction* MaybePaddedAndSlicedInput( MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { int64 dim = conv_dnums.input_spatial_dimensions(i); - padding_config.mutable_dimensions(dim)->set_edge_padding_low( - std::max(0LL, conv_window.dimensions(i).padding_low())); - padding_config.mutable_dimensions(dim)->set_edge_padding_high( - std::max(0LL, conv_window.dimensions(i).padding_high())); - padding_config.mutable_dimensions(dim)->set_interior_padding( - conv_window.dimensions(i).base_dilation() - 1); + if (conv_window->dimensions(i).padding_low() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + conv_window->dimensions(i).padding_low()); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + conv_window->dimensions(i).padding_high()); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } + if (conv_window->dimensions(i).base_dilation() != 1) { + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window->dimensions(i).base_dilation() - 1); + conv_window->mutable_dimensions(i)->set_base_dilation(1); + } } 
PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction( @@ -75,7 +85,7 @@ HloInstruction* MaybePaddedAndSlicedInput( input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } - if (window_util::HasNegativePadding(conv_window)) { + if (window_util::HasNegativePadding(*conv_window)) { // If the window has negative padding, insert a kSlice that explicitly // applies negative padding. // @@ -89,10 +99,14 @@ HloInstruction* MaybePaddedAndSlicedInput( int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. - start_indices[dim] += - std::max(0LL, -conv_window.dimensions(i).padding_low()); - limit_indices[dim] -= - std::max(0LL, -conv_window.dimensions(i).padding_high()); + if (conv_window->dimensions(i).padding_low() < 0) { + start_indices[dim] += -conv_window->dimensions(i).padding_low(); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() < 0) { + limit_indices[dim] -= -conv_window->dimensions(i).padding_high(); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } } input = @@ -140,25 +154,22 @@ bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( // Insert slices and/or pads between the convolution and its input and/or // kernel operand. + Window new_conv_window = conv->window(); HloInstruction* new_input = MaybePaddedAndSlicedInput( - conv->window(), conv->convolution_dimension_numbers(), + &new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(0)); HloInstruction* new_kernel = - MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(), + MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(), conv->mutable_operand(1)); - // Remove the padding from convolution's window field. 
These paddings are - // made explicit with the inserted pads. - Window new_conv_window = conv->window(); + // Remove the window dilation from convolution's window field. These paddings + // are made explicit with the pads inserted by MaybePaddedKernel(). for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { WindowDimension* dim = new_conv_window.mutable_dimensions(i); // The size of the kernel may have changed so update the Window to match. dim->set_size(new_kernel->shape().dimensions( conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_base_dilation(1); dim->set_window_dilation(1); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index 4ce877f62a55c960765314670288ee626c5fc15b..e81850db69edced29ea31bb2a526b0503bf8a453 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -77,7 +77,11 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { return false; } - if (window_util::HasWindowReversal(conv->window())) { + // CuDNN can perform either cross correlation (no reversal), + // or convolution (all dimensions reversed). + if (dnums.input_spatial_dimensions_size() == 2 + ? 
!window_util::AllOrNoneReversed(conv->window()) + : window_util::HasWindowReversal(conv->window())) { return false; } return true; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 492d290bf4a27a91fa14dea95ac62d90bc1fa28a..3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -138,6 +138,7 @@ Status RunCudnnConvImpl(CudnnConvParams params, const int num_dimensions = window.dimensions_size(); CHECK_LE(num_dimensions, 3); + CHECK_GE(num_dimensions, 1); // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. // This matches the behavior of TF (see definition of conv1d in @@ -148,10 +149,15 @@ Status RunCudnnConvImpl(CudnnConvParams params, output_shape.element_type()) << ShapeUtil::HumanString(output_shape); + // If one dimension is reversed, we need to have all dimensions reversed (so + // we're doing convolution not cross correlation). 
+ const bool dims_reversed = window.dimensions()[0].window_reversal(); + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dims_reversed, dim.window_reversal()); CHECK_EQ(dim.padding_low(), dim.padding_high()); CHECK_EQ(dim.base_dilation(), 1) << "cudnn does not support base dilation; it " @@ -198,6 +204,7 @@ Status RunCudnnConvImpl(CudnnConvParams params, ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); convolution_descriptor.set_group_count(feature_group_count); + convolution_descriptor.set_convolution_not_crosscorr(dims_reversed); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -363,14 +370,12 @@ StatusOr GetCudnnConvParams( params.output_shape = &conv_result_shape; params.fusion.emplace(); auto& fusion = *params.fusion; - if (backend_config.activation_mode() < - static_cast(se::dnn::ActivationMode::kNumActivationModes)) { - fusion.mode = static_cast( - backend_config.activation_mode()); - } else { + if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { return InternalError("Bad activation mode: %s", backend_config.ShortDebugString()); } + fusion.mode = static_cast( + backend_config.activation_mode()); fusion.side_input_scale = backend_config.side_input_scale(); params.input_buf = operand_buffers[0]; params.filter_buf = operand_buffers[1]; diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6dcdaf1cfe06e446deed847aaf29088a7ed10e13..2ab754a471070d5f90a3eaebd0600ff180d2fe5d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -161,6 +161,16 @@ StatusOr 
GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); + HloOpcode opcode = op->opcode(); + + if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() && + (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) { + return llvm_ir::EmitCallToIntrinsic( + opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum + : llvm::Intrinsic::minnum, + {lhs_value, rhs_value}, {lhs_value->getType()}, b_); + } + switch (op->opcode()) { case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 30c1f9088968305ad0207164ecb07ba13cc89ee6..470457935acacb8940af241dadb393d770786939 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -229,7 +229,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - (user->fusion_kind() == HloInstruction::FusionKind::kInput && + (IsReduceInputFusion(*user) && LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 57426327822d95a42f407ed7488f35acfd3623d2..ae2e718db29803a085401969a7d9b09abf690a6c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -51,7 +51,7 @@ GpuExecutable::GpuExecutable( const string& ptx, const std::vector& cubin, std::pair compute_capability, 
std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 0e276282e40fba0ae4881a51dad0c7c9e8d1c081..2b3c77f5b82aa94f44d8de56caf0f4d31c05e0cb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,7 @@ class GpuExecutable : public Executable { GpuExecutable(const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 2d31fd5570c468b0c42fa308535fd335f3588a79..452e763a8eaadc805cd3a3859a68e2a31598fd36 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -55,7 +55,7 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, }); } -bool IsInputFusibleReduction(const HloInstruction& instr) { +bool IsReduceInputFusion(const HloInstruction& instr) { if (instr.IsMultiOutputFusion()) { for (const HloInstruction* operand : instr.fused_expression_root()->operands()) { @@ -67,17 +67,70 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { return true; } } - return false; - } else if (instr.opcode() == HloOpcode::kFusion) { - if (IsReductionToVector(*instr.fused_expression_root())) { - CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) - << " Fusion rooted at reduction-to-vector op must be of kind kInput: " - << instr.ToString(); - return true; + } else if (instr.opcode() == 
HloOpcode::kFusion && + IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; + } + return false; +} + +bool IsInputFusibleReduction(const HloInstruction& instr) { + return IsReduceInputFusion(instr) || IsReductionToVector(instr); +} + +bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, + const HloInstruction& instr2) { + // Returns the instruction that determines the emitter used for lowering, + // sometimes referred to as "the real hero". + auto get_real_hero = + [&](const HloInstruction* instr) -> const HloInstruction* { + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduction-to-vector operand of the + // fusion root, because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionToVector(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } + return fused_expression_root; } + return instr; + }; + + // Multi-output fusion kernels share a common parallel loop. The loop + // dimensions are determined by instruction shapes. + auto get_loop_shape = [&](const HloInstruction* element_instr) { + // Special-case reduction-to-vector ops: The loop dimensions are determined + // by the shape of the first operand. + if (IsReductionToVector(*element_instr)) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // All shapes of the root tuple of multi-output fusions should agree, i.e. all + // root ops should have equal output shapes. An exception are + // reduction-to-vector ops. Here the input shapes of the reduction (first + // operand shape) and the reduction dimensions need to match.
+ auto* instr_1 = get_real_hero(&instr1); + auto* instr_2 = get_real_hero(&instr2); + // TODO(tjoerg): Relax the shape constraint. The datatype does not matter. + if (IsReductionToVector(*instr_1) && IsReductionToVector(*instr_2) && + (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) || + instr_1->dimensions() != instr_2->dimensions())) { return false; } - return IsReductionToVector(instr); + // The elementwise output shapes must be the same (including layout). + // TODO(tjoerg): Further relax the constraint. The datatype does not matter. + return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1), + get_loop_shape(instr_2)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index f7c24a0d5bbfcc61389ea19ae7f769671e4e974d..e9d7ba1c4cfa865532a0d06c2ed883a2fea4e2cd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -33,16 +33,29 @@ namespace gpu { bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, const HloInstruction& reduce); -// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` -// is either an unfused reduction-to-vector op, an input fusion rooted at a -// reduction-to-vector op, or a multi-output input fusion with at least one -// reduction-to-vector op root. // Note that reduction ops are lowered in different ways. Reduce input fusions // are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at // reduction-to-vector ops. Other reduction ops are lowered by // GpuElementalIrEmitter and fused like elementwise ops. + +// Whether `instr` is an input fusion rooted at a reduction-to-vector op or a +// multi-output input fusion with at least one reduction-to-vector op root. +bool IsReduceInputFusion(const HloInstruction& instr); + +// Whether `instr` is fusible as root of a reduce input fusions, i.e. 
`instr` +// is either an unfused reduction-to-vector op or a reduce input fusion. bool IsInputFusibleReduction(const HloInstruction& instr); +// Whether instruction shapes are compatible for multi-output fusion, i.e. +// whether the emitters support lowering the resulting fusion. +// This function works for both sibling and producer-consumer multi-output +// fusion. +// So far, multi-output fusion is supported for loop fusions and reduce +// input fusions only. It is up to the caller to ensure the instructions +// themselves are fusible! +bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, + const HloInstruction& instr2); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index d91b7bc61fda5a07c163a07ec0e1644d2ad9db49..15d4ee206ce8debcb8a5dbc6ec65d29ba257d302 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -178,7 +178,7 @@ TEST_F(GpuFusibleTest, EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ReductionToVector) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -191,10 +191,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ElementalReduction) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -207,10 +208,11 @@ TEST_F(GpuFusibleTest,
IsInputFusibleReduction_ElementalReduction) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -225,10 +227,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -243,10 +246,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -263,11 +267,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); 
EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + IsReduceInputFusion_MultiOutputInputReduceFusionWithExtraOutputs) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -284,10 +289,11 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -304,11 +310,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + IsReduceInputFusion_MultiOutputLoopFusionReduceAndElementwiseOp) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -325,8 +332,304 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_LoopFusions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = 
f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_IgnoreFpPrecision) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + ROOT convert = f16[6400]{0} convert(p0.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_Reduce) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 
= f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion, *reduce)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_Elementwise) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* div = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion, *div)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_MultiOutputLoopFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, 
f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1)->operand(0); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*reduce, *exp)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_DifferentLayouts) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{0,1,2} parameter(1) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{0,1} reduce(p1, c0), dimensions={2}, to_apply=scalar_add + ROOT root = (f32[2,2]{0,1}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + const HloInstruction* reduce = + 
module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*reduce, *exp)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_MultiOutputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1)->operand(0); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, 
ShapesCompatibleForMultiOutputFusion_ReduceFusions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce_1 { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + } + + fused_reduce_2 { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={0}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 + reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_DifferentReduceDimensions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce_1 { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + } + + fused_reduce_2 { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={2}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, 
calls=fused_reduce_1 + reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_NoReductionToVector) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + // Note that reduce is not a reduction-to-vector. + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 
91609c730b6c0d666eb607fb42b918c0f8f250e5..1126943624a3771433ecac591545d335c1890115 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -37,7 +37,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { public: GpuHloOrdering(const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order); + const std::vector& thunk_launch_order); ~GpuHloOrdering() override = default; // Only the entry computation can possibly be sequentially ordered, and only @@ -56,7 +56,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { GpuHloOrdering::GpuHloOrdering( const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order) + const std::vector& thunk_launch_order) : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { @@ -150,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering( // However, if the total order is A,B,D,C,E, then C and E can run // concurrently. void BFSLaunchOrder(const HloComputation* computation, - std::vector* launch_order) { + std::vector* launch_order) { // This topological sort uses two data structures: // 1. `incoming_edge_count` which keeps track of the number of incoming // edges to each HLO; @@ -158,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation, // // The sorting algorithm repeatedly pops the top from the queue and deletes // that HLO from the graph, making more HLOs incoming-edge free. 
- std::deque queue; + std::deque queue; std::unordered_map incoming_edge_count; - for (const auto& hlo : computation->instructions()) { + for (auto* hlo : computation->instructions()) { if (hlo->operand_count() == 0) { queue.push_back(hlo); } else { @@ -172,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation, } while (!queue.empty()) { - const HloInstruction* x = queue.front(); + HloInstruction* x = queue.front(); queue.pop_front(); launch_order->push_back(x); - for (const HloInstruction* y : x->users()) { + for (HloInstruction* y : x->users()) { --incoming_edge_count[y]; if (incoming_edge_count[y] == 0) { queue.push_back(y); @@ -195,14 +195,14 @@ StatusOr> GpuHloSchedule::Build( std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. - const HloComputation* entry_computation = module.entry_computation(); + HloComputation* entry_computation = module.entry_computation(); if (stream_assignment.StreamCount() == 1) { // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( HloInstructionSequence sequence, ScheduleComputation( - *entry_computation, [pointer_size](const BufferValue& buffer) { + entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); schedule->thunk_launch_order_ = sequence.instructions(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 07a7fc67aa555845c3de57e574ab582403ec0490..7f224ffe4f03f8f05b0f1907628d99d9df387770 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -46,7 +46,7 @@ class GpuHloSchedule { // Returns the total order of thunk launches, represented in terms of HLO // instructions. 
- const std::vector& ThunkLaunchOrder() const { + const std::vector& ThunkLaunchOrder() const { return thunk_launch_order_; } @@ -60,7 +60,7 @@ class GpuHloSchedule { private: GpuHloSchedule(); - std::vector thunk_launch_order_; + std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 6d3aed15ebe7d925eda00a72177a03a2264a640c..91db7151f22fd75b20244878bee86d65acd1d304 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -33,7 +33,7 @@ namespace gpu { class GpuHloScheduleTest : public HloTestBase { protected: - using HloVec = std::vector; + using HloVec = std::vector; // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); @@ -44,7 +44,7 @@ class GpuHloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } - std::unique_ptr CreateNewUnverifiedModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -79,7 +79,7 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr streams = AssignStreams(*module); @@ -139,7 +139,7 @@ TEST_F(GpuHloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add3)); std::unique_ptr streams = AssignStreams(*module); @@ -209,7 +209,7 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloInstruction* add 
= builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr streams = AssignStreams(*module); @@ -288,7 +288,7 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr streams = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 1c0a23fa3eb38961d420aff05e412c3b4d8524e7..f59da2caa18646676297e66dd329c66fb5fddf1b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -65,8 +65,8 @@ HeuristicLayoutAssignment(const HloInstruction* instr, VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // Empirically we've found with Volta and cudnn 7 that backward-input convs - // with stride are significantly faster with NCHW layouts. + // Empirically we've found with Volta and cudnn <= 7.3 that backward-input + // convs with stride are significantly faster with NCHW layouts. // // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), // which on paper gives good performance. However, there are two observations: @@ -75,11 +75,17 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // * we've also observed that for mixed layouts, cuDNN transposes data back // and forth from a different layout combination. If we end up with // transposes anyway, we prefer to have them in XLA, as they can be fused. - // TODO(timshen): Figure out the exact condition. This may be achieved by - // auto-tuning layouts offline. 
- if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && - window_util::HasStride(instr->window())) { - return kAllNCHW; + if (auto* dnn = stream_executor->AsDnn()) { + auto version_status = dnn->GetVersion(); + if (version_status.ok()) { + auto version = version_status.ConsumeValueOrDie(); + if (std::make_tuple(version.major_version(), version.minor_version()) <= + std::make_tuple(7, 3) && + instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return kAllNCHW; + } + } } // For other Volta f16 convolutions, use NHWC. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 8cc76c872c61634ca4344d8a8cdf8c6a75aea2ac..2ffc8bfb49b205dced0d540ba72426e72d95e596 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -61,7 +61,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(add)); @@ -148,7 +148,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { {operand, scale, offset, mean, variance, epsilon, feature_index}, kCudnnBatchNormForwardInferenceCallTarget)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -217,7 +217,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, kCudnnBatchNormForwardTrainingCallTarget)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); 
HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -298,7 +298,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { feature_index}, kCudnnBatchNormBackwardCallTarget)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 43f43b50e4a6478f343088194871cc9d380bd2d2..6151dd8ff4c92bb81bd756c68cc9377633c8c9d5 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -80,7 +80,7 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { // This function limits the maximum number of operands to a fusion. // // There's a cap on how many parameters we can pass to a CUDA kernel, but -// exactly what that limit is is hazy, as it depends on (among other things) how +// the exact limit is hazy, as it depends on (among other things) how // much GPU constant memory is in use for other purposes. // // Moreover, we don't even know at the point that we're running fusion how many @@ -181,7 +181,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } } else if (consumer->operand_count() == 2 && - consumer->opcode() == HloOpcode::kAdd) { + consumer->opcode() == HloOpcode::kAdd && + consumer->operand(other_operand_index) != producer) { // Fuse a bias add into the output of the dot. 
return true; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index fb77bc4b8eb497d09014da96769b52aa606510af..688604cd36e5a45debf855aacd29d05ecda92341 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -134,7 +134,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -358,6 +358,29 @@ TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) { op::Parameter())); } +TEST_F(InstructionFusionTest, + DotOperationFusion_DontOutputFuseDuplicateOperands) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[60,1]{1,0} parameter(1) + c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT d = f32[50,1]{1,0} add(c, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool fused_something, + 
GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { @@ -723,7 +746,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { sum = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); } - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(b.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 42fb38dffae31b0f4322216545027e067cab250d..33e41a2782b5932430eea621d3cea2c6634f292f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -268,5 +268,17 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { + return b->CreateAnd( + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index f373d4a8393a047aba599b0fae954e98a740161e..ebf4d926b7a280e10b09a2532caba7ad6ab3ceb2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -155,6 +155,10 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* 
EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Emits code that determines whether the current thread is thread 0 within +// block 0 of the kernel. +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 7fcdd805ed32004a96ecc0da7de1d89bcf1b6229..6693f66d62d8b04d1b78e001fdb515b34539c67f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -63,9 +63,6 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -97,6 +94,18 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { + VLOG(2) << "HandleAddDependency: " << add_dependency->ToString(); + const HloInstruction* operand = add_dependency->operand(0); + // Add_Dependency is a no-op, but we still want to bind it to an llvm::Value + // sometimes, e.g., when its operand is a constant or a bitcast of a + // constant. 
+ if (bindings_.BoundToIrValue(*operand)) { + bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand)); + } + return Status::OK(); +} + Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { auto operand = get_tuple_element->operand(0); CHECK(bindings_.BoundToIrValue(*operand)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 56c3f452006f9e2d5c37cc3b54701b2367abfa14..2da46c016935d0e927879bbfb0d05cfc4899d818 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -100,6 +100,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 87b6cd640acc41074c40e1d397b9334b76029fd5..fb040aff30d48bf5817946ce53d37bc6685941e4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -65,11 +65,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -88,6 +88,8 @@ limitations under the License. namespace xla { namespace gpu { +using llvm_ir::KernelMappingScheme; + namespace { using absl::InlinedVector; @@ -1188,7 +1190,7 @@ Status IrEmitterUnnested::EmitColumnReduction( .EmitLoop(IrName(reduce), index_ty); } -static std::pair ComputeTilingSchemeForReduction( +static std::pair ComputeKernelMappingSchemeForReduction( int64 depth, int64 width, int64 kWarpSize) { constexpr int64 kTargetNumElementsPerThread = 64; int64 x_tile_size = kTargetNumElementsPerThread; @@ -1322,7 +1324,7 @@ Status IrEmitterUnnested::EmitRowReduction( int64 x_tile_size; int64 z_tile_size; std::tie(x_tile_size, z_tile_size) = - ComputeTilingSchemeForReduction(depth, width, kWarpSize); + ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. 
@@ -2171,7 +2173,18 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; Shape keys_shape = sort->operand(0)->shape(); + int64 dimension_to_sort = sort->dimensions(0); + // In case there is a 'values' parameter that is a iota, we take note and use + // it later to ensure a stable sort. Otherwise, we don't guarantee a stable + // sort. + int64 iota_values_parameter_index = -1; for (int64 i = 0; i < sort->operand_count(); ++i) { + if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && + ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && + Cast(sort->operand(i))->iota_dimension() == + dimension_to_sort) { + iota_values_parameter_index = i; + } ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the @@ -2196,7 +2209,6 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } } - int64 dimension_to_sort = sort->dimensions(0); uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); @@ -2298,8 +2310,9 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } } return llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_masks, - &b_, launch_dimensions, + dimension_to_sort, keys_array, values_arrays, + iota_values_parameter_index, IrName(sort), xor_masks, &b_, + launch_dimensions, xor_masks.size() > 1 ? 
num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, kTileSize); @@ -2385,7 +2398,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } @@ -3103,8 +3116,18 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multioutput fusion, we need to emit each operand and the root. + // Emit the tuple pointers in one thread. We could do this at any point in + // the kernel, but we do it at the beginning in the hopes of reducing register + // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the + // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); + TF_RETURN_IF_ERROR( + KernelSupportLibrary(&b_).If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); + })); + + // For multioutput fusion, we need to emit each operand and the root. 
TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -3113,8 +3136,6 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( &hlo, launch_dimensions.launch_bound(), &b_))); b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); } @@ -3146,31 +3167,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } -int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_in_reduced_shape_arrays->reserve(num_outputs); - output_reduced_shapes->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(), - reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); - } - } else { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - hlo.shape().element_type(), reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); - } - return num_outputs; -} int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, @@ -3199,335 +3195,508 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the -// thread lives within a square tile of size tile_size (so thread blocks are of 
-// size tile_size * tile_size). -std::tuple CalculateYXCoordinateWithinTile( - llvm::IRBuilder<>* builder, llvm::Value* tile_size, - int64 threads_per_tile) { - // Calculate the starting element coordinate within a tile for the current - // thread, (y, x) from thread_id. - llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_tile, - llvm::cast(thread_id)); - thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), - /*isSigned=*/true, "thread.id.x"); - auto x = builder->CreateURem(thread_id, tile_size); - auto y = builder->CreateUDiv(thread_id, tile_size); - return std::make_tuple(y, x); -} - -// Reads block_idx.x, casts it to type index_ty, and adds the assumption that -// it's in the range [0, num_blocks]. -llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, - int64 num_blocks) { - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); -} - -// Emits code to process up to (tile_size/num_rows) elements in a tile, given -// `emit_elem_function` is the function to emit code to process one element, `y` -// and `x` are the coordinates for the first element to process, and `index` is -// the index for the origin of the tile. Emits bounds check to ensure that each -// processed element is within the boundary defined by `tile_width` and -// `tile_height`. 
+void EmitFullTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + for (int64 i = 0; i < tile_size_y; i += num_threads_y) { + IrArray::Index source_idx_y = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), + KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); + } + } +} + +void EmitPartialTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + + 
ksl->IfReturnVoid( + "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { + // tile_height_bound = + // ceil(tile_height / num_threads_y) * num_threads_y + llvm::Value* ceiling_of_ratio = builder->CreateUDiv( + builder->CreateAdd(tile_height, llvm::ConstantInt::get( + index_ty, num_threads_y - 1)), + llvm::ConstantInt::get(index_ty, num_threads_y)); + llvm::Value* tile_height_bound = builder->CreateMul( + ceiling_of_ratio, + llvm::ConstantInt::get(index_ty, num_threads_y)); + ksl->ForReturnVoid( + loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/tile_height_bound, + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), + [&] { + emit_elem_function( + source_idx.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc); + }); + }); + }); + } +} + +// Emits code to process up to +// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile, +// given `emit_elem_function` is the function to emit code to process one +// element, `y` and `x` are the intra-tile coordinates for the first element +// to process, and `index` is the index for the origin of the tile. Information +// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits +// bounds check to ensure that each processed element is within the boundary +// defined by `tile_width` and `tile_height`. 
void EmitTiledElementalCodeWithBoundsCheck( - int64 tile_size, int64 num_rows, const IrArray::Index& index, - const string& loop_name, KernelSupportLibrary* ksl, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - // Emits a constant value with index type. - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - // Adds `addend` to the given `dim` of `index`. 
- auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - auto emit_full_tile = [&] { - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(index_typed_constant(i), y); - emit_elem_function(source_idx, y_loc); - } - }; - - auto emit_last_row = [&] { - ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { - // tile_height_upper_bound = - // ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, - index_typed_constant(num_rows - 1)), - index_typed_constant(num_rows)), - index_typed_constant(num_rows)); - ksl->ForReturnVoid( - loop_name, /*start=*/index_typed_constant(0), - /*end=*/tile_height_upper_bound, - /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { - auto y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), - y_loc); - }); - }); - }); - }; ksl->IfReturnVoid( "full_tile", builder->CreateAnd( - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), - emit_full_tile, emit_last_row); + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), + tile_width), + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), + tile_height)), + [&] { + EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, + emit_elem_function); + }, + [&] { + EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, tile_height, tile_width, index_ty, + emit_elem_function); + }); } } // namespace -// Emits a kernel for the given hlo instruction using a tiled 0-2-1 
transpose -// algorithm to improve the memory access patterns for the input parameters -// which have a shape that is a 0-2-1 transpose of the output tensors. +// Emits code to process a tensor element in a tile for the given kCopy HLO that +// performs a 0-2-1 transpose. // -// For the purpose of tiling, the output tensors have a logical shape of three -// components 0-2-1 while the relevant input parameters have a logical shape of -// three components 0-1-2 in the order major to minor. The x- and y- dimensions -// of the tensors are tiled in square tiles of edge length `kTileSize`. Each -// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each -// thread copies kTileSize/kNumRows elements from the input to a shared memory -// tile, then the otherwise "regular hlo kernel" reads from the shared memory -// instead of the original input. +// index: The index for the first output element in the normalized tensor. The +// normalized tensor is the resulting tensor after collapsing contiguous +// dimensions that play the same role in the transpose. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. +// kernel_info: Other information to support the kernel code generation. +void IrEmitterUnnested::EmitTileElementForCopy( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + // TODO(jlebar): Add AA metadata to this load. 
+ llvm::Instruction* load_from_shmem_buffer = + Load(GEP(tiled_param_info->GetBufferForParameter(0), + {b_.getInt64(0), x_loc, y_loc}), + "output_element"); + llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); + Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( + hlo->shape().element_type(), + kernel_info->GetKernelMappingScheme()->GetDimensionsInElements()); + // When the output_reduced_shape is a 0-2-1 transpose of the input shape, + // the 0-2-1 transpose is achieved through EmitWriteArrayElement. + output_array.CastToShape(output_reduced_shape, &b_) + .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_); +} + +// Emits code to process a tensor element in a tile for the given kLoop fusion +// HLO containing parameters that are 0-2-1 transpose of its outputs. // -// This is similar to the following CUDA algorithm in TensorFlow: -// https://goo.gl/MStRV6. -// -// `kTileSize` should usually be same as warp size. We currently choose 32 for -// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles. -LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, absl::Span reduced_output_dims, - absl::Span tiled_param_ids) { - // Parameters for the tiling algorithm. - constexpr int64 kTileSize = 32; - constexpr int64 kNumRows = 4; - constexpr int64 kThreadsPerTile = kTileSize * kNumRows; - - // Construct IrArrays for the inputs and outputs. +// index: The index for the first output element in the normalized tensor, that +// is the resulting tensor after collapsing contiguous dimensions that play +// the same role in the transpose. +// kernel_info: Other information to support the kernel code generation. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. 
+void IrEmitterUnnested::EmitTileElementForFusion( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); - int64 num_outputs = output_arrays.size(); - std::vector param_arrays = ConstructIrArrayForInputs(*hlo); - int64 num_params = param_arrays.size(); + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elem_emitter); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = + kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, output_arrays[0].GetShape()); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + DCHECK(output_value->getType()->isStructTy()); + DCHECK_EQ(output_value->getType()->getStructNumElements(), + output_arrays.size()); + for (int64 i = 0; i < output_arrays.size(); ++i) { + output_arrays[i].EmitWriteArrayElement( + untiled_index, ExtractValue(output_value, i), &b_); + } + } else { + output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_); + } +} + +// Emits a block of tiles, given a function object to emit one tile. 
+void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, + llvm::Type* index_ty) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); + absl::Span dims_in_block = + mapping_scheme->GetDimensionsInBlocks(); + absl::Span block_sizes = mapping_scheme->GetBlockSizes(); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Emit all the tiles for a given dimension in a tile block. + auto emit_tiles_for_block_dim = + [&](const string& loop_name, const IrArray::Index& starting_tile, + int dim_id, + const std::function + emit_next_block_dim) { + if (block_sizes[dim_id] == 1) { + emit_next_block_dim(starting_tile); + } else { + llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id]; + llvm::Value* block_size_for_dim = + index_typed_constant(block_sizes[dim_id]); + llvm::Value* block_id_for_dim = + b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); + llvm::Value* last_block_for_dim = + index_typed_constant(dims_in_block[dim_id] - 1); + llvm::Value* last_block_size_for_dim = index_typed_constant( + dims_in_tile[dim_id] - + (dims_in_block[dim_id] - 1) * block_sizes[dim_id]); + llvm::Value* num_tiles_in_block = + Select(ICmpEQ(last_block_for_dim, block_id_for_dim), + last_block_size_for_dim, block_size_for_dim); + + ksl.ForReturnVoid( + loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); + } + }; + + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + const bool block_contains_multi_tiles = + mapping_scheme->GetNumberOfTilesInOneBlock() > 1; + + // Emit the tile with a 
given tile_index, by calculating the tight bounds for + // each dimension of the tile and then calling emit_one_tile. + auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { + std::vector output_tile_bounds(3); + for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; + ++i) { + int64 tile_size_for_dim = mapping_scheme->GetTileSizeForDimension(i); + // Only last row or column may not have full size. + llvm::Value* is_last_row = + ICmpEQ(tile_index[i], index_typed_constant(dims_in_tile[i] - 1)); + int64 partial_row_size = + reduced_dims[i] - (dims_in_tile[i] - 1) * tile_size_for_dim; + output_tile_bounds[i] = + Select(is_last_row, index_typed_constant(partial_row_size), + index_typed_constant(tile_size_for_dim), "tile_bound"); + } + + IrArray::Index tile_origin = + mapping_scheme->GetElementIndexForTileOrigin(tile_index); + emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles); + }; + const IrArray::Index starting_block = + mapping_scheme->EmitBlockIndex(index_ty); + const IrArray::Index starting_tile_for_dim_z = + mapping_scheme->GetTileIndexForBlockOrigin(starting_block); + + // Emit the three dimensional block of tiles. + emit_tiles_for_block_dim( + "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ, + [&](const IrArray::Index& starting_tile_for_dim_y) { + emit_tiles_for_block_dim( + "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY, + [&](const IrArray::Index& starting_tile_for_dim_x) { + emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x, + KernelMappingScheme::DimX, + emit_one_tile_for_tile_index); + }); + }); +} + +// Emits a kernel for the hlo instruction using the given kernel mapping scheme. +// +// unnested_hlo: The unnested hlo instruction for which the kernel is generated. +// Currently, these hlo instructions are supported: kLoop fusion, kCopy. 
+// tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of +// other tensors with the same dimensions and need to be tiled and transposed. +// mapping_scheme: The tiling scheme to use. +// kernel_generator: Contains function objects for code generation, such as +// element generator, block prologue and epilogue generators. +// kernel_info: Represent other information to support the code generation +// of the tiled kernel for the hlo. +LaunchDimensions IrEmitterUnnested::EmitKernel( + HloInstruction* unnested_hlo, absl::Span tiled_param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + + std::vector param_arrays = ConstructIrArrayForInputs(*unnested_hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector param_shmem_buffers(num_params, nullptr); for (int64 id : tiled_param_ids) { - const HloInstruction* param = hlo->operand(id); - // Add 1 to the minor dimension to reduce shared memory bank conflicts. 
- llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - param->shape().element_type(), module_), - kTileSize + 1), - kTileSize); - auto* tile_base_ptr = llvm_ir::AllocateSharedMemoryTile( - b_.GetInsertBlock()->getParent()->getParent(), tile_type, - IrName(hlo, StrCat("tile", id))); - param_shmem_buffers[id] = tile_base_ptr; + const HloInstruction* param = unnested_hlo->operand(id); + param_shmem_buffers[id] = + mapping_scheme->GetSharedMemoryBufferForElementType( + llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(), + module_), + IrName(unnested_hlo, StrCat("tile", id))); VLOG(3) << "Added shmem buffer for parameter " << id << ": " - << llvm_ir::DumpToString(*tile_base_ptr); + << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result - // for the purpose of tiling. Calculate the logical output dimensions in the - // tile from the reduced output dimensions. 
- std::vector output_dims_in_tiles = std::vector( - reduced_output_dims.begin(), reduced_output_dims.end()); - CHECK_EQ(output_dims_in_tiles.size(), 3); - for (int i = 1; i < 3; ++i) { - output_dims_in_tiles[i] = - CeilOfRatio(output_dims_in_tiles[i], kTileSize); - } - const int64 num_tiles = - absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); - LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); - - llvm::Type* index_ty = - GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); + LaunchDimensions launch_dimensions = LaunchDimensions( + mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); + llvm::Type* index_ty = GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_); auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - // Cast each output IrArray to its corresponding reduced shape and keep the - // reduced shape live during IR emission. - std::vector output_in_reduced_shape_arrays; - std::vector output_reduced_shapes; - CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, - &output_in_reduced_shape_arrays), - num_outputs); + // For multioutput fusion, one thread needs to output a tuple with pointers to + // all the individual outputs. We could do this at any point in the kernel, + // but we do it at the beginning in the hopes of reducing register pressure, + // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. 
+ if (unnested_hlo->IsMultiOutputFusion()) { + TF_CHECK_OK(KernelSupportLibrary(&b_).If( + "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + return Status::OK(); + })); + } // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. std::vector param_in_reduced_shape_arrays; std::vector param_reduced_shapes; - CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( - *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, - &param_reduced_shapes, &param_in_reduced_shape_arrays), - num_params); + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape( + *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims, + &param_reduced_shapes, &param_in_reduced_shape_arrays); + DCHECK_EQ(num_shapes, num_params); // Calculate the starting element coordinate within a tile for the current // thread, (y, x) from thread_id. llvm::Value* x; llvm::Value* y; - std::tie(y, x) = CalculateYXCoordinateWithinTile( - &b_, index_typed_constant(kTileSize), kThreadsPerTile); - - // Calculate the index for the current output tile from block_id. - const IrArray::Index output_tile_index( - GetBlockIdx(&b_, index_ty, num_tiles), - ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, - output_dims_in_tiles), - &b_); - - // Output tile origin is the index for the first element of the current output - // tile. - const IrArray::Index output_tile_origin = [&] { - IrArray::Index index = output_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); - } - return index; - }(); - - // Calculate the input tile origin from the output tile origin. 
- const IrArray::Index input_tile_origin( - Permute({0, 2, 1}, output_tile_origin.multidim())); + std::tie(y, x) = mapping_scheme->EmitThreadYXCoordinate(index_ty); - // Calculate the current output tile bounds in each of the logical dimensions. - std::vector output_tile_bounds(3); - for (int i = 1; i < 3; ++i) { - // Only last row or column may not have full size. - output_tile_bounds[i] = - Select(ICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); - } + kernel_info->SetLaneId( + mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x + : nullptr); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { - EmitTiledElementalCodeWithBoundsCheck( - kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, - tile_height, emit_elem_function); + llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, + &ksl, &b_, y, x, tile_height, + tile_width, emit_elem_function); }; - // Adds `addend` to the given `dim` of `index`. 
- auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = Add(index[dim], addend); - return index; - }; - const IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); - - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); - } - }); + auto emit_one_tile = [&](const IrArray::Index& output_tile_origin, + absl::Span output_tile_bounds, + bool block_contains_multi_tiles) { + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + const IrArray::Index input_index = + input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. 
+ emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + // If shared memory transpose is needed, wait for all threads to reach this + // point, lest we copy a value from tile to output before the other thread + // copies it from input to tile. This is `__syncthreads` in CUDA. + if (!tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } - llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + const IrArray::Index output_index = + output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values for - // the tiled parameters are read from the shmem buffers. 
- if (hlo->opcode() == HloOpcode::kCopy) { - emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - // TODO(jlebar): Add AA metadata to this load. - llvm::Instruction* load_from_shmem_buffer = - Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, load_from_shmem_buffer, &b_); - }); - } else { - CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + // Write to output[index] by emitting code like normal, except that values + // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), - &elem_emitter); - tiled_param_info.set_y(y_loc); - fused_emitter.SetTiledParameterInfo(&tiled_param_info); - TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); - IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( - index, output_reduced_shapes[0], output_arrays[0].GetShape(), - &b_); - const llvm_ir::ElementGenerator& output_generator = - fused_emitter.GetRootGenerator(); - llvm::Value* output_value = - output_generator(untiled_index).ValueOrDie(); - if (hlo->IsMultiOutputFusion()) { - CHECK(output_value->getType()->isStructTy()); - CHECK_EQ(output_value->getType()->getStructNumElements(), - output_in_reduced_shape_arrays.size()); - for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { - output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, ExtractValue(output_value, i), &b_); - } - } else { - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, output_value, &b_); - } + 
output_index, "output", output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + kernel_generator.GetTileElementGenerator()(unnested_hlo, index, + kernel_info, y_loc, x_loc); }); + // If a tile block contains multiple tiles and shared memory buffers are + // used, we need to wait for all threads to finish using the shared memory + // buffer for the current tile before we move on to process the next tile + // and overwrite the shared memory buffers. + if (block_contains_multi_tiles && !tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } + }; + + const BlockPrologueGenerator& block_prologue_generator = + kernel_generator.GetBlockPrologueGenerator(); + if (block_prologue_generator) { + block_prologue_generator(unnested_hlo, kernel_info); } - // For multioutput fusion, emit a tuple with all the individual outputs. - if (hlo->IsMultiOutputFusion()) { - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); + EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + + const BlockEpilogueGenerator& block_epilogue_generator = + kernel_generator.GetBlockEpilogueGenerator(); + if (block_epilogue_generator) { + block_epilogue_generator(unnested_hlo, kernel_info); } return launch_dimensions; } +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// with a shape that is a 0-2-1 transpose of the output tensor shape. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape +// of three components 0-1-2 in the order major to minor. The x- and y- +// dimensions of the tensors are tiled in square tiles with an edge length +// `kTileSize`. 
Each thread block of `kTileSize` x `kNumRows` threads +// transposes one tile: each thread copies kTileSize/kNumRows elements from +// the input to a shared memory tile, then the otherwise "regular HLO kernel" +// reads from the shared memory instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more +// efficient to launch fewer blocks so each transposes many tiles. +LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { + constexpr int kNumRows = 4; + KernelMappingScheme mapping_scheme( + reduced_output_dims, /*tile_size_y=*/kWarpSize, + /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1}, + /*num_threads_y=*/kNumRows, + /*num_threads_x=*/kWarpSize, &b_); + TileElementGenerator element_generator; + if (hlo->opcode() == HloOpcode::kCopy) { + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + }; + } else { + DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); + }; + } + KernelCodegenInfo kernel_info(&mapping_scheme); + KernelCodeGenerator kernel_generator(std::move(element_generator)); + return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info); +} + namespace { // Returns true to indicate it is safe to use the tile based shared memory // transpose 
implementation to implement the kernel for the instruction. @@ -3562,8 +3731,8 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { ? ShapeUtil::GetSubshape(hlo->shape(), {0}) : hlo->shape(); - // If the output_shape is reduced to 021 shape, find all the parameters of the - // hlo that are in the corresponding 012 shape. + // If the output_shape is reduced to 021 shape, find all the parameters of + // the HLO that are in the corresponding 012 shape. std::vector params_012; optional> reduced_dims_021; for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); @@ -3600,9 +3769,9 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { } // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the - // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb - // shared memory per SM. (This is increased to 96kb in Volta, but we don't - // use this, in part because it eats into our L1 cache space.) + // elements are of size 4 bytes), and CUDA has an architectural limit of + // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we + // don't use this, in part because it eats into our L1 cache space.) // // For correctness we need to ensure that we don't make more than 48kb worth // of shmem tiles per block. And for performance, we'd probably like to use @@ -3610,9 +3779,9 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // gpu core. // // We say without benchmarks that we want at least 3 threads/block, - // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose - // which params get the shmem transpose treatment arbitrarily; it's not clear - // if there's a Right Choice. + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We + // choose which params get the shmem transpose treatment arbitrarily; it's + // not clear if there's a Right Choice. 
// // This is only sound if tiled transposes are the only place where we use // shared memory in fusions. If in the future other fusible ops use shared @@ -3666,10 +3835,10 @@ Status IrEmitterUnnested::EmitConstantGlobals() { } // These globals will be looked up by name by GpuExecutable so we need to - // give them an external linkage. Not all of their uses are visible in the - // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely - // preserves their names (like available_externally), we also need to ensure - // that they stick around even if they're "unused". + // give them an external linkage. Not all of their uses are visible in + // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that + // merely preserves their names (like available_externally), we also need + // to ensure that they stick around even if they're "unused". // // We may have to be more more clever here in the future if we notice that // we're keeping around too many globals because of their linkage. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 334c0b3c20b0888fa9b167a8979221f0184a82e7..e09ed657a812be6ab4859a0e365a51c45a37bfed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" namespace xla { @@ -47,6 +48,94 @@ namespace gpu { // class IrEmitterUnnested : public IrEmitter { public: + // Parameter block_contains_multi_tiles indicates whether a tile block + // consists of multiple tiles or not.
If the tile block contains only one + // tile, there is no need to use atomic operation to accumulate a local result + // to a global result to implement reduction. + using TileGenerator = + std::function output_tile_bounds, + bool block_contains_multi_tiles)>; + // KernelCodegenInfo records the common information to support the code + // generation for a kernel to process tensor elements by blocks. A block of + // tensor elements may contain one or multiple tiles. The code generators that + // generate code for tile elements or block prologue/epilogue refer to this + // class in their prototypes. If the implementations of such code generators + // require other information that is specific to the HLO instructions, the + // implementations need to define and use derived classes of this class. + class KernelCodegenInfo { + public: + explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) + : mapping_scheme_(mapping_scheme), + tiled_param_info_(nullptr), + lane_id_(nullptr) {} + + void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { + CHECK_EQ(tiled_param_info_, nullptr); + tiled_param_info_ = tiled_param_info; + } + + llvm::Value* GetLaneId() const { return lane_id_; } + llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const { + return mapping_scheme_; + } + llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { + return tiled_param_info_; + } + + private: + llvm_ir::KernelMappingScheme* mapping_scheme_; + llvm_ir::TiledParameterInfo* tiled_param_info_; + llvm::Value* lane_id_; + }; + + // A function object to prepare for the code generation for a tile block. + using BlockPrologueGenerator = + std::function; + // A function object to finalize the code generation for a tile block. + using BlockEpilogueGenerator = + std::function; + // A function object to generate code to process one element in a tile. + // + // hlo: the instruction for which the code is generated.
+ // index: the index for the first output element of the current thread. + // y_loc: The y coordinate within a tile. + // x_loc: The x coordinate within a tile. + // kernel_info: Other information to support the kernel code generation. + using TileElementGenerator = std::function; + + // KernelCodeGenerator records the code generator objects that generate code + // for tile elements or tile block prologue/epilogue. + class KernelCodeGenerator { + public: + explicit KernelCodeGenerator( + TileElementGenerator tile_element_generator, + BlockPrologueGenerator block_prologue_generator = {}, + BlockEpilogueGenerator block_epilogue_generator = {}) + : tile_element_generator_(std::move(tile_element_generator)), + block_prologue_generator_(std::move(block_prologue_generator)), + block_epilogue_generator_(std::move(block_epilogue_generator)) {} + + const TileElementGenerator& GetTileElementGenerator() const { + return tile_element_generator_; + } + const BlockPrologueGenerator& GetBlockPrologueGenerator() const { + return block_prologue_generator_; + } + const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const { + return block_epilogue_generator_; + } + + private: + TileElementGenerator tile_element_generator_; + BlockPrologueGenerator block_prologue_generator_; + BlockEpilogueGenerator block_epilogue_generator_; + }; + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context); @@ -82,7 +171,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* after_all) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -205,22 +294,32 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions 
EmitHlo021Tile(HloInstruction* hlo, absl::Span reduced_output_dims, absl::Span tiled_param_ids); + // Emits a kernel for an unnested HLO instruction. + LaunchDimensions EmitKernel(HloInstruction* unnested_hlo, + absl::Span param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info); + void EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, llvm::Type* index_ty); + // Emits code to process a tensor element in a tile for the given kCopy HLO + // that performs a 0-2-1 transpose. + void EmitTileElementForCopy(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given kLoop fusion + // HLO containing parameters that are 0-2-1 transpose of its outputs. + void EmitTileElementForFusion(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. std::vector ConstructIrArrayForInputs( const HloInstruction& hlo); - // For each output of the `hlo` instruction, constructs the reduced shape for - // the output with the given `reduced_output_dims` and cast the original - // output IrArray element in `output_arrays` to the reduced shape. Returns - // the number of outputs. - int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, - const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in // `param_buffers` to find out whether the input has a reduced shape. 
If the // input has a reduced shape, constructs the reduced shape for the input and diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2..bd53b90b42d8e657a3ee58e7ca03fb60522aae28 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -177,13 +177,6 @@ std::unique_ptr GetTargetMachine( } TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math(), - &target_options); - - // Enable FMA synthesis. - target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. target_options.MCOptions.AsmVerbose = false; @@ -206,8 +199,7 @@ std::unique_ptr GetTargetMachine( } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - Optional(RelocModel), Optional(CMModel), - codegen_opt_level)); + getRelocModel(), getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -401,8 +393,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, int32 opt_level = hlo_module_config.debug_options().xla_backend_optimization_level(); - CHECK_GE(opt_level, 2) - << "The XLA GPU backend doesn't support unoptimized code generation"; + if (opt_level < 2) { + LOG(ERROR) << std::string(80, '*'); + LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code " + "generation but "; + LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level + << "!"; + LOG(ERROR) << "(Supported configuration is " + "--xla_backend_optimization_level >= 2.)"; + LOG(ERROR) << std::string(80, '*'); + } AddOptimizationPasses(opt_level, /*size_level=*/0, 
target_machine.get(), &module_passes, @@ -453,18 +453,21 @@ void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // * 3-6 gives similar results as 2; // * >6 start hurting the performance of at least dot product kernels. // - // TODO(jingyue): The current threshold only considers the numbr of IR + // TODO(jingyue): The current threshold only considers the number of IR // instructions which do not accurately reflect the true cost. We need a // better cost model. FeedLLVMWithFlags({"-bonus-inst-threshold=2"}); - // TODO(b/22073864): Increase limit when scan memory dependency. - // This helps to reduce more redundant load instructions. + // Increase limit when scanning memory dependencies. This helps to reduce + // more redundant load instructions. // // The specific value is currently large enough for s3d in shoc benchmark, // which contains a lot of load instructions and many arithmetic instructions // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + // Use div.approx -- it matters for some float-division heavy benchmarks. 
+ FeedLLVMWithFlags({"-nvptx-prec-divf32=0"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); // Initialize the NVPTX target; it's the only target we link with, so call its diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index d9b06828e2b5d334873c88cb49c2e0d5675bb5fe..01fddcede64d1bb02ab89db5fc9524893c2d47a4 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -41,50 +41,7 @@ GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, HloInstruction* instr2) { - auto get_element_instr = - [&](const HloInstruction* instr) -> const HloInstruction* { - const HloInstruction* element_instr = instr; - if (instr->opcode() == HloOpcode::kFusion) { - auto fused_expression_root = instr->fused_expression_root(); - if (instr->IsMultiOutputFusion()) { - // If possible, we want to pick a reduce operand of the fusion root, - // because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionToVector(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - } else { - element_instr = fused_expression_root; - } - } - return element_instr; - }; - - auto get_element_shape = [&](const HloInstruction* element_instr) { - // Special handling of kReduce instructions -- the fusion - // applies to the first operand. - if (IsReductionToVector(*element_instr)) { - return element_instr->operand(0)->shape(); - } - return element_instr->shape(); - }; - - // The shapes in all tuple operands should agree, unless it is a reduce. - // In that case, the operand of the reduce needs to have the same shape - // as the other tuple operands, but also we need to compare the output - // shapes of the reduces. 
- auto* element_instr_1 = get_element_instr(instr1); - auto* element_instr_2 = get_element_instr(instr2); - if (element_instr_1->opcode() == HloOpcode::kReduce && - element_instr_2->opcode() == HloOpcode::kReduce && - !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { - return false; - } - // The elementwise output shapes must be the same (including layout). - return ShapeUtil::EqualIgnoringFpPrecision( - get_element_shape(element_instr_1), get_element_shape(element_instr_2)); + return ShapesCompatibleForMultiOutputFusion(*instr1, *instr2); } bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { @@ -205,7 +162,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " is not a loop fusion."; continue; } - if (!ShapesCompatibleForFusion(producer, consumer)) { + if (!ShapesCompatibleForMultiOutputFusion(*producer, *consumer)) { VLOG(3) << producer->name() << " has an incompatible shape."; continue; } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index dc221f22a74f0875e08e01890ce8ac8fe072cd9d..d16c87ba5c63aa582753fe949e9e39ee2d8b81e5 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -580,7 +580,7 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { // ... // where each of the (pi * pj)'s is represented as a fusion node so that // multi-output fusion will pay attention to it. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index de04ed85c30717f5be7c5485ff3b68270c8ec188..f3e17d888242a36c268dcbfa0d6530f80cedceb0 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -173,13 +174,16 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); + pipeline.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); - pass.AddPass( - /*is_layout_sensitive=*/false, + AlgebraicSimplifierOptions options( [](const Shape&, const Shape&) { return false; }); + options.set_enable_permutation_sort_replacement(true); + pass.AddPass(options); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -248,11 +252,13 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
- pipeline.AddPass>( - /*is_layout_sensitive=*/true, + AlgebraicSimplifierOptions options( /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_permutation_sort_replacement(true); + pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. // @@ -473,7 +479,8 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor) { + int cc_minor, + bool disable_ptx_optimizations) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); @@ -513,6 +520,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } + if (disable_ptx_optimizations) { + ptxas_args.push_back("-O0"); + } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); @@ -733,8 +743,9 @@ StatusOr> NVPTXCompiler::RunBackend( } } - const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + const std::vector cubin = CompilePtxOrGetCachedResult( + ptx, cc_major, cc_minor, + module->config().debug_options().xla_gpu_disable_ptxas_optimizations()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -766,9 +777,9 @@ StatusOr> NVPTXCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + bool disable_ptx_optimizations) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity 
activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -796,8 +807,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor); + StatusOr> maybe_cubin = CompilePtx( + *cache_ptx, cc_major, cc_minor, disable_ptx_optimizations); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() @@ -810,7 +821,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, // binaries are not available. We don't want to spam logs with // identical warnings in this case. - // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // TODO(jlebar): we should implement a LOG_FIRST_N and LOG_EVERY_N // for more general usage. static std::atomic warning_done(false); log_warning = !warning_done.exchange(true); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f79ae2990ae7d6e6985b15727a72358289121aa9..be5e31a50112686841e6f18b76f382a56e61bafc 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -97,8 +97,9 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. - std::vector CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, int cc_minor); + std::vector CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + bool disable_ptx_optimizations); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. 
This is important for diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index f2ef11e1e6ac2405ac2a35fec7b79add9d2b6c17..31a5d7a8c04e9863830e2026fc73cd7ded8c322e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -30,7 +30,7 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: - std::unique_ptr CreateNewUnverifiedModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -55,7 +55,7 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr assignment = AssignStreams(*module); @@ -76,7 +76,7 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr assignment = AssignStreams(*module); @@ -120,7 +120,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr assignment = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index d2f30ae7bc4f65675f10a2f87ba934cf308f663a..d917320e36363c4fa7e4c0055e8f3345cbc610a2 
100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -26,7 +26,7 @@ namespace gpu { // Tests that verify IR or PTX emitted by the GPU backend is as expected. class GpuCodegenTest : public LlvmIrGenTestBase { protected: - // Like HloTestBase::CreateNewUnverifiedModule(), with a flag for configuring + // Like HloTestBase::CreateNewVerifiedModule(), with a flag for configuring // the ftz option. std::unique_ptr CreateNewUnverifiedModuleWithFTZ(bool ftz); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 268b48a1cadeef911dfda7e827ae0cd154040be8..a1ed8499040359fe7265a7317b0577a990a2234c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -46,7 +46,7 @@ TEST_F(GpuCopyTest, UseMemcpy) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // There should not be any kernel prefixed "copy". diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index d0ccd8619bde9ddd560989380b403efed5c5f42c..5e524faab18947f5793dc2ae34e9329a446d4235 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -75,16 +75,16 @@ class GpuFtzDisabledTest : public GpuFtzTest { // Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. 
TEST_F(GpuFtzEnabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.f32 - CHECK: mul.ftz.f32 - CHECK-NOT: mul.f32 + CHECK-NOT: mul.rn.f32 + CHECK: mul.rn.ftz.f32 + CHECK-NOT: mul.rn.f32 )"); } TEST_F(GpuFtzDisabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.ftz.f32 - CHECK: mul.f32 - CHECK-NOT: mul.ftz.f32 + CHECK-NOT: mul.rn.ftz.f32 + CHECK: mul.rn.f32 + CHECK-NOT: mul.rn.ftz.f32 )"); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index da8e513a2c3b61eb9f780ac628e4befeb918b939..6814be779e0b02c38e3bc7008f036b845d88cb6f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); // Check the optimized IR as the unoptimized IR contains dead udiv and urem. 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index ea1fee040dd536bcd1c4f8c5dd4f3aaa8dca32f9..3019215c015a4e0aa094a62424d650ced0de2a0e 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -48,7 +48,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -73,7 +73,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { builder.AddInstruction(HloInstruction::CreateTuple({add, square})); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -95,7 +95,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { // reduce in the foreseeable future. But if that turns out to be wrong, I give // you, future reader, permission to delete this test. 
TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation* reduce_computation; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 14285459b5a7fc0325dc5af80e57bef4ee4b7299..ca0a78034d7dc83d17ad72202914d95f37ac122b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -47,7 +47,7 @@ TEST_F(GpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 141f3219387940a08ef22cbcc0be0971a14c2cd6..6b2d76764a077dc6cfa3f9ddc6e525ab330323be 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( ThunkSchedule::ThunkSchedule( std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order) + const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { std::unordered_map hlo_to_thunk; @@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule( InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } - for (const HloInstruction* hlo : hlo_total_order) { + for (HloInstruction* hlo : hlo_total_order) { if (hlo_to_thunk.count(hlo)) { thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); } diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h 
b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index d3352994f845a535233612a17e19107511ce0622..43b628a1baf0e79a3197f3cfad3547991642eaed 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -46,7 +46,7 @@ class ThunkSchedule { public: ThunkSchedule(std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order); + const std::vector& hlo_total_order); // Returns the total order of executing all the thunks. const std::vector& TotalOrder() const { return thunk_total_order_; } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c7f51127649664189050e2128ae1e56547358c23..2dce7749bbd8da2673ae607eee3d731d9917e8fe 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -29,7 +29,7 @@ namespace { class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(CreateNewUnverifiedModule()), + : module_(CreateNewVerifiedModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index fad3215fc81e1012ddaa5a6458bc98f90ab97f18..dc40b9446ad1bffcb757543e52fc9ab20de6d52e 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -258,7 +258,7 @@ class HeapSimulatorTracker { // Constructor for testing a single entry computation. 
HeapSimulatorTracker( const string& name, std::unique_ptr computation, - const std::vector& instruction_sequence) { + const std::vector& instruction_sequence) { HloModuleConfig config; module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); @@ -286,7 +286,7 @@ class HeapSimulatorTracker { // Similar to the single entry computation constructor above, but runs the // simulation over the entire module. void RunWholeModule( - const std::vector& full_module_sequence) { + const std::vector& full_module_sequence) { points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -294,7 +294,7 @@ class HeapSimulatorTracker { HloSchedule schedule(module_.get()); absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { - const HloInstruction* instruction = full_module_sequence[i]; + HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index dbab62f847e8ca5e0b46dfd4162a0f4222640252..414c63271245315f037d04924c9291a9cd5b7a77 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -51,7 +51,7 @@ message HloInstructionProto { string name = 1; string opcode = 2; - xla.Shape shape = 3; + xla.ShapeProto shape = 3; xla.OpMetadata metadata = 7; @@ -132,7 +132,7 @@ message HloInstructionProto { string custom_call_opaque = 53; // Shape of outfeed request. - xla.Shape outfeed_shape = 29; + xla.ShapeProto outfeed_shape = 29; // Describes the dimension numbers used for a dot operation xla.DotDimensionNumbers dot_dimension_numbers = 30; @@ -190,7 +190,7 @@ message HloInstructionProto { // 'operand_shapes_with_layout' must contain a shape with layout for each // operand. 
bool constrain_layout = 56; - repeated Shape operand_shapes_with_layout = 57; + repeated xla.ShapeProto operand_shapes_with_layout = 57; } // Serialization of HloComputation. @@ -205,7 +205,8 @@ message HloComputationProto { repeated HloInstructionProto instructions = 2; // The program shape (with layout) of this computation. - xla.ProgramShape program_shape = 4; + + xla.ProgramShapeProto program_shape = 4; // The id of this computation. int64 id = 5; @@ -251,6 +252,41 @@ message HloInputOutputAliasProto { repeated AliasEntryProto entries = 1; } +message DynamicParameterBindingProto { + // A list of bindings which indicates that the `target_param_dim_num` in + // the subshape `target_param_index` of parameter `target_param_num` + // is a dynamic dimension and its real dynamic size is represented + // by `dynamic_param_index` in parameter `dynamic_param_num`. + // + // As an example, imagine we have a program: + // + // ENTRY main { + // a = f32[] parameter(0) + // b = f32[10] parameter(1) + // ROOT root = (f32[], f32[10]) tuple(%a, %b) + // } + // + // Let's say 'b' (param index 1) is a dynamic shape whose input has + // an upper bound of 10 and real size is determined at runtime. 'a' + // represents the real size of b's first dimension. + // + // In this case, the fields are set in the following way: + // dynamic_param_num = 0 + // dynamic_param_index = {} + // target_param_num = 1 + // target_param_index = {} + // target_param_dim_num = 0 + message Binding { + int64 dynamic_param_num = 1; + repeated int64 dynamic_param_index = 2; + int64 target_param_num = 3; + repeated int64 target_param_index = 4; + int64 target_param_dim_num = 5; + } + + repeated Binding entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -262,7 +298,7 @@ message HloModuleProto { repeated HloComputationProto computations = 3; // The host program shape (with layout) of the entry computation. 
- xla.ProgramShape host_program_shape = 4; + xla.ProgramShapeProto host_program_shape = 4; // The id of this module. int64 id = 5; @@ -272,6 +308,8 @@ message HloModuleProto { // Describes alias information between inputs and outputs. HloInputOutputAliasProto input_output_alias = 8; + + DynamicParameterBindingProto dynamic_parameter_binding = 9; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 0c20d207ddbca82e2f87800d331d1bace39a512e..ff122b529bdcdcc69d2245136e19101902dbf957 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -499,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const { proto.add_instructions()->Swap(&instruction_proto); } proto.set_root_id(root_instruction()->unique_id()); - *proto.mutable_program_shape() = ComputeProgramShape(); + *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); return proto; } @@ -711,6 +711,8 @@ bool HloComputation::operator==(const HloComputation& other) const { return eq(root_instruction(), other.root_instruction()); } +uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } + Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { @@ -795,7 +797,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + const std::vector& order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -825,9 +827,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. 
template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, const std::vector&) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, const std::vector&) const; Status HloComputation::Accept( const std::function& visitor_func) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fc7d2035e5bd0b99fa9e7a026430172f686019d4..c584e4c7ca5770533f28352b0df9dadd9dbe1860 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -264,6 +264,12 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; + // Generates a hash value of an HLO computation. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO computations, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const; + // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -301,7 +307,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + const std::vector& order) const; // Same as Accept() above, but the visitor is given as a function. 
Status Accept(const std::function& visitor_func); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 1e7a6e197f5b6c3070b7cad2c14f62521290a4c9..0361c87428f6e4c031d95492a5bc782ad388e5b5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -20,19 +20,19 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = match; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -65,7 +65,7 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); @@ -73,7 +73,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map_computation = @@ -85,7 +85,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map1_computation = @@ -119,7 +119,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -134,7 +134,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); @@ -170,7 +170,7 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), 
UnorderedElementsAre(constant1, constant2, constant3, constant4)); @@ -192,7 +192,7 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); @@ -217,7 +217,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. class TestVisitor : public DfsHloVisitorWithDefault { @@ -257,11 +257,11 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_THAT(copy, op::Copy(constant)); + EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant)))); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -274,12 +274,13 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = 
computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(tuple_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } @@ -297,7 +298,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { ShapeTree indices_to_copy(constant->shape(), /*init_value=*/true); EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy) .ValueOrDie(), - op::Copy(constant)); + GmockMatch(m::Copy(m::Op().Is(constant)))); } { @@ -330,10 +331,11 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); - EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}), - copies_added.element({1}))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({0})), + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({1}))))); } { @@ -346,8 +348,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, + GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) == nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -363,8 +366,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, 
&indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) != nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -376,12 +380,12 @@ TEST_F(HloComputationTest, DeepCopyToken) { // copied. auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. - EXPECT_THAT(copy, op::AfterAll()); + EXPECT_THAT(copy, GmockMatch(m::AfterAll())); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { @@ -393,14 +397,15 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); // Only the array (second tuple element) should be copied. The token is passed // through transparently. 
- EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(copy, GmockMatch(m::Tuple( + m::GetTupleElement(m::Op().Is(tuple)), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); } TEST_F(HloComputationTest, CycleDetection) { @@ -440,16 +445,18 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); } @@ -466,7 +473,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); @@ -505,7 +512,7 @@ TEST_F(HloComputationTest, Stringification) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* 
computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -540,7 +547,7 @@ TEST_F(HloComputationTest, StringificationIndent) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = @@ -576,7 +583,7 @@ TEST_F(HloComputationTest, StringificationCanonical) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index d12f920722e20a3390a99f74c8a10c7c9e3fdf6c..4f81dc94e577a63c09ae4019e5e8158252c712ce 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -22,21 +22,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = xla::match; + using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { @@ -49,13 +50,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42); } @@ -70,13 +72,14 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42.0f); } @@ -91,13 +94,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().Get({0}), 42); EXPECT_EQ(computation->root_instruction()->literal().Get({1}), 19); } @@ -138,7 +142,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } } @@ -165,7 +169,7 @@ TEST_F(HloConstantFoldingTest, Slice) { EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } @@ -190,7 +194,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); 
EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; @@ -240,7 +244,8 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_FALSE(result); - EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Reduce())); } const char* const kConstantFoldLargePad = R"( @@ -260,7 +265,7 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { EXPECT_FALSE(result); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Pad(op::Constant(), op::Constant())); + GmockMatch(m::Pad(m::Constant(), m::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index fdfb38b858c32ba5b092ec2db84d4bac487c3e78..df7d3826dbad1f264a5dc53312c062900155b0f6 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -419,6 +419,21 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { } Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { + // This instruction is used to enforce ordering at compile time. No code is + // emitted. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAddDependency( + const HloInstruction* add_dependency) { + // This instruction is used to enforce ordering at compile time. No code is + // emitted. 
+ current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 8ced9d776e150ac587e9ac3ed0beffbc38dc5503..33983119c9b00a248c0e8dcc5815c6367192dca3 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -101,6 +101,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleAddDependency(const HloInstruction* add_dependency) override; Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 6a15b3440c6f9bd2cac5ea10a0883330260b89e5..ff32faf298dd1f04c5b769f2a88f76a7a1e18ae7 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -387,7 +387,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -429,7 +429,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, 
c1, broadcast)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -472,7 +472,7 @@ TEST_F(DomainCostAnalysis, DomainCost) { auto domain = builder.AddInstruction( HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); - auto hlo_module = CreateNewUnverifiedModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 5dcf6bc985ff18fa6fc1ab5a5692914b4597d065..3ed3d3c11c71dc534f193ba3ffb556b0eb0c80e4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -466,6 +466,21 @@ bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { return changed; } +bool HloDataflowAnalysis::UpdateAddDependencyValueSet( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand. + CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency); + const InstructionValueSet& operand_set = + GetInstructionValueSet(add_dependency->operand(0)); + InstructionValueSet& add_dependency_set = + GetInstructionValueSet(add_dependency); + if (operand_set != add_dependency_set) { + add_dependency_set = operand_set; + return true; + } + return false; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -622,6 +637,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( HloInstruction* instruction) { // Recompute from operands. 
switch (instruction->opcode()) { + case HloOpcode::kAddDependency: + return UpdateAddDependencyValueSet(instruction); case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); case HloOpcode::kDomain: @@ -795,6 +812,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; + case HloOpcode::kAddDependency: case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index abac398c04fc4c418d8814a0097db4434bc1cd9c..ece17fc4c3ea0261474df5d53c088dd05016e1e4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -193,6 +193,7 @@ class HloDataflowAnalysis { bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); + bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); // Propagate the dataflow through the module. void Propagate(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 6422346c1011b95bb511a1fcdfee5c84647f0571..f7a1f19a6f52befd58a405d0e406d7d0d37a8e57 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -43,7 +43,7 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(CreateNewUnverifiedModule()) {} + HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
@@ -1877,6 +1877,30 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { } } +TEST_P(HloDataflowAnalysisTest, AddDependency) { + string module_string = R"( +HloModule AddDependency +ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p = f32[3] parameter(0) + %token = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloDataflowAnalysis::Run(*module)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency); + + // The after-all and parameter should define a value. Add-dependency should + // not. + EXPECT_EQ(analysis->values().size(), 2); + EXPECT_FALSE(analysis->ValueIsDefinedAt(root)); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 6c8095d39774b247e136442c92c8ecf17432701c..1fa4259a3e42286cbc911907eea563e6ca6f8611 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -59,7 +59,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -110,7 +110,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = 
module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -150,7 +150,7 @@ TEST_F(HloDceTest, ControlDependencies) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency between two instructions. @@ -175,7 +175,7 @@ TEST_F(HloDceTest, ControlDependencies) { // Tests that a dead call instruction is removed. TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Called computation for the call instruction. @@ -323,7 +323,7 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { } TEST_F(HloDceTest, RemoveDeadSubcomputation) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 7fcafafc097a623686ca98a7cb3c6256c7904f6d..3a7652a8dc856b23c8988c4676916c8199e78860 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -396,6 +397,16 @@ StatusOr HloEvaluator::EvaluateDotOp( return Evaluate(cloned_instruction.get()); } +Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { + const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0)); + Literal result(bitcast->shape()); + TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes()); + memcpy(result.untyped_data(), operand_literal.untyped_data(), + operand_literal.size_bytes()); + evaluated_[bitcast] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -1046,8 +1057,15 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -Status HloEvaluator::HandleAfterAll(HloInstruction* token) { - evaluated_[token] = LiteralUtil::CreateToken(); +Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) { + evaluated_[after_all] = LiteralUtil::CreateToken(); + return Status::OK(); +} + +Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) { + // AddDependency just forwards its zero-th operand. 
+ evaluated_[add_dependency] = + GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone(); return Status::OK(); } @@ -1279,10 +1297,10 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, key_value_vector.push_back( std::make_pair(keys_data[i], values_data[i])); } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); + std::stable_sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); std::vector result_keys; // We use a InlinedVector here because we need to convert it to an // absl::Span later, and this would not work with std::vector. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 07f8d0aad4af0b07303b4e485b3630cc75bcb519..45ed8131dc6b71f706fce45d65b206363dd79ac3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -144,6 +144,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Operations that are type-agnostic or always return a specific type, such as // HandleIsFinite where boolean is always returned. 
// + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -180,7 +182,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleAfterAll(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* after_all) override; + + Status HandleAddDependency(HloInstruction* add_dependency) override; Status HandleSort(HloInstruction* sort) override; @@ -221,16 +225,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const Literal& operand_literal) { const auto shape = instruction->shape(); const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape), - ShapeUtil::HumanString(operand->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); Literal result(shape); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index d95b6ad04f2c446b423a3aaef4de333ed2968883..4eaaab20ea0add17d9b49b1b2b97991af0438dcc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -35,6 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -2765,6 +2767,33 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } +TEST_P(HloEvaluatorTest, Bitcast) { + // Regression test for b/114735354. + constexpr absl::string_view hlo_text_base = R"( +HloModule Bitcast + +ENTRY main { + param = %s[32,121]{1,0} parameter(0) + ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param) +} +)"; + string hlo_text; + if (use_bfloat16_) { + hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16"); + } else { + hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32"); + } + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + if (use_bfloat16_) { + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual.data())); + } else { + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); + } +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index ebed875eb4954bc9a9da3f182005fa3d44326493..b87fc3e34012e75ee07bff6c1e113dce404f83cb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -161,9 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloOpcodeString(hlo_instruction->opcode())); } - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive type. 
- template ::value>::type* = nullptr> @@ -596,7 +593,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) { + Status HandleDivide(HloInstruction* divide) override { return HandleDivide(divide); } @@ -1556,10 +1553,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto& row_data = row_to_sort.data(); std::vector result_data(row_data.begin(), row_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); + std::stable_sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), {sort_dim_elements})); sorted_row.PopulateR1(absl::Span(result_data)); @@ -2546,12 +2543,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || - std::is_same::value || - std::is_same::value>::type* = nullptr> + std::is_integral::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); - std::vector data(iota->shape().dimensions(iota->iota_dimension())); + // Avoid using std::vector since std::vector does not convert to + // absl::Span. 
+ absl::InlinedVector data( + iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); auto result = LiteralUtil::CreateR1(data); @@ -2568,9 +2567,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || - std::is_same::value || - std::is_same::value)>::type* = nullptr> + !(std::is_integral::value || + std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { return InvalidArgument("Unsupported type for iota"); } @@ -2722,17 +2720,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -2756,19 +2745,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. 
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape()), - ShapeUtil::HumanString(ehs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, lhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(rhs->shape(), ehs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..c919dbd82d3668c477bf37074f1d56f8cb7d9506 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + +namespace xla { + +namespace { + +StatusOr ReplaceGetSize(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kGetDimensionSize) { + return false; + } + HloComputation* computation = instr->parent(); + + TF_ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); + uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + return true; +} + +} // namespace + +StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { + bool changed = false; + HloProto proto; + *proto.mutable_hlo_module() = module->ToProto(); + for (auto* computation : module->computations()) { + for (auto instruction : computation->instructions()) { + TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + changed = changed || replaced; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..30f44c23a835b3bcc935caaa917e040e07c4e703 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -0,0 +1,36 @@ +/* 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Pass to replace a kGetDimensionSize instruction with a constant instruction. +class HloGetDimensionSizeRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "hlo-get-dimension-size-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a86aebdd5b64240e6e07d8e8050c0c8681cce765 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HloGetDimensionSizeRewriterTest : public HloTestBase { + protected: + HloGetDimensionSizeRewriterTest() {} +}; + +TEST_F(HloGetDimensionSizeRewriterTest, Ok) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = u32[] get-dimension-size(p), dimensions={0} + size1 = u32[] get-dimension-size(p), dimensions={1} + ROOT mul = u32[] multiply(size0, size1) +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); 
+} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3]{0} parameter(0) + ROOT gds = s64[] get-dimension-size(p), dimensions={0} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = f32[2,5] parameter(0) + ROOT gds = u32[] get-dimension-size(p), dimensions={2} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 05cc1593e4ef4fc52b94e0536628645b1fa2abbc..302eca656be53a3cec86ddbf05a7fa3925c5185b 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -111,11 +113,6 @@ class NodeFilter { result == kSomeUsersOmitted; } - bool ShowFusionSubcomputation(const HloInstruction* instr) const { - CHECK_EQ(instr->opcode(), HloOpcode::kFusion); - return Show(instr) && !SomeOrAllOperandsOmitted(instr); - } - private: std::function filter_; }; @@ -240,34 +237,28 @@ string HtmlLikeStringSanitize(absl::string_view s) { // it to a short string lets us tell the user what the subcomputation is without // drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { + namespace m = match; + if (computation->instruction_count() != 3) { return nullopt; } - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2) { - return nullopt; - } - - // Check that both of the operands to the root are parameters. - const HloInstruction* operand0 = root->operand(0); - const HloInstruction* operand1 = root->operand(1); - if (operand0->opcode() != HloOpcode::kParameter || - operand1->opcode() != HloOpcode::kParameter) { - return nullopt; - } - - // Check that the two operands of root are param0 and param1. All of the - // opcodes we recognize are commutative, so we're OK with either order. 
- auto n0 = operand0->parameter_number(); - auto n1 = operand1->parameter_number(); - if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { + const HloInstruction *param0, *param1; + if (!Match(root, m::Op() + .WithNumOperands(2) + .WithShape(m::Shape().IsEffectiveScalar()) + .WithBinaryOperandsAnyOrder( + m::Parameter(&param0, 0) + .WithShape(m::Shape().IsEffectiveScalar()), + m::Parameter(&param1, 1) + .WithShape(m::Shape().IsEffectiveScalar())))) { return nullopt; } - // If the params are reversed, check that the operation being performed is - // commutative. - if (n0 == 1) { + // If the params are reversed (i.e. operand0 is param1 and operand1 is + // param0), check that the operation being performed is commutative. + if (root->operand(0) == param1) { + CHECK_EQ(root->operand(1), param0); switch (root->opcode()) { case HloOpcode::kLe: case HloOpcode::kGe: @@ -279,13 +270,6 @@ optional MatchTrivialComputation(const HloComputation* computation) { } } - // Check that the root and params are all effective scalars. - if (!ShapeUtil::IsEffectiveScalar(root->shape()) || - !ShapeUtil::IsEffectiveScalar(operand0->shape()) || - !ShapeUtil::IsEffectiveScalar(operand1->shape())) { - return nullopt; - } - // If we recognize the root's opcode, we've successfully pattern-matched! switch (root->opcode()) { case HloOpcode::kAdd: @@ -578,7 +562,7 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { // Show the subcomputation if we're showing any of its members. 
return std::any_of( - computation_->instructions().begin(), computation_->instructions().end(), + subcomp->instructions().begin(), subcomp->instructions().end(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -987,6 +971,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: case HloOpcode::kAfterAll: + case HloOpcode::kAddDependency: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -1267,12 +1252,12 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( class GraphRendererRegistry { public: - void AddRenderer(GraphRendererInterface* graph_renderer) { + void SetRenderer(std::shared_ptr graph_renderer) { tensorflow::mutex_lock lock(mu_); graph_renderer_ = graph_renderer; } - GraphRendererInterface* GetDefaultRenderer() { + std::shared_ptr GetDefaultRenderer() { tensorflow::mutex_lock lock(mu_); return graph_renderer_; } @@ -1284,20 +1269,21 @@ class GraphRendererRegistry { private: tensorflow::mutex mu_; - GraphRendererInterface* graph_renderer_ = nullptr; + std::shared_ptr graph_renderer_ GUARDED_BY(mu_); }; } // namespace -Registrar::Registrar(GraphRendererInterface* dumper) { - GraphRendererRegistry::Default()->AddRenderer(dumper); +Registrar::Registrar(std::shared_ptr dumper) { + GraphRendererRegistry::Default()->SetRenderer(dumper); } namespace { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, + int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. std::unordered_map nodes; @@ -1404,6 +1390,56 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } +// Gets a node filter that includes nodes on all paths from `from` to `to`. 
If +// the all-paths set contains more than max_nodes elements, includes the nodes +// on the shortest paths and sets hit_limit to true. +NodeFilter MakeNodeFromToFilter(const HloInstruction* from, + const HloInstruction* to, int64 max_nodes, + bool* hit_limit) { + *hit_limit = false; + + // Elements in the queue are paths through the graph. + std::deque> queue; + queue.push_front({from}); + + // Compute the set of nodes we want to show using a slightly-modified + // Dijkstra's algorithm. The only real difference is, rather than stopping + // when we find a (shortest) path, we continue until we've found max_nodes + // nodes on some path. + std::unordered_set visited; + std::unordered_set to_display = {from, to}; + while (!queue.empty() && to_display.size() < max_nodes) { + std::vector path = std::move(queue.front()); + queue.pop_front(); + if (!visited.insert(path.back()).second) { + continue; + } + + for (const auto* user : path.back()->users()) { + if (user == to) { + auto it = path.begin(); + for (; it != path.end() && to_display.size() < max_nodes; ++it) { + to_display.insert(*it); + } + if (it != path.end()) { + *hit_limit = true; + } + } else if (!visited.count(user)) { + auto new_path = path; + new_path.push_back(user); + queue.push_back(std::move(new_path)); + } + } + } + + return NodeFilter([=](const HloInstruction* instr) { + if (instr == from || instr == to) { + return kHighlightNode; + } + return to_display.count(instr) ? 
kNormalNode : kHideNode; + }); +} + string SaveGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const string& dest_path) { @@ -1483,7 +1519,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeFilter(&node, radius); + NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); string graph = HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) @@ -1491,6 +1527,29 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config) { + CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; + auto debug_options = from.GetModule()->config().debug_options(); + + bool hit_limit = false; + NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); + string label; + if (!hit_limit) { + label = StrCat("All paths from ", from.name(), " to ", to.name()); + } else { + label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), + " to ", to.name(), + "

***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " + "NODES***

"); + } + string graph = + HloDotDumper(from.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); +} + void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix) { Env* env = Env::Default(); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 0b11f34abb7f0d937a24d11f4dc5d2d6a0aae6e7..de1eefab776f9c3d2c73959a5cd267e938a78a32 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -66,6 +66,12 @@ string DumpGraph(const HloComputation& computation, const string& label, string DumpNeighborhoodAround(const HloInstruction& node, int radius, bool show_backend_config = false); +// Dumps nodes on any of the paths from `from` to `to`. If there are more than +// max_nodes on all paths, restricts to the max_nodes nodes on the shortest +// paths. +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config = false); + // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. // @@ -87,13 +93,13 @@ void DumpText(const HloModule& module, const string& label, // Class that registers a graph renderer. class Registrar { public: - Registrar(GraphRendererInterface* dumper); + Registrar(std::shared_ptr dumper); }; -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ - static ::xla::hlo_graph_dumper::Registrar \ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)(new factory, \ - ##__VA_ARGS__) +#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) 
\ + static ::xla::hlo_graph_dumper::Registrar \ + XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)( \ + std::make_shared(), ##__VA_ARGS__) // __COUNTER__ must go through another macro to be properly expanded #define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 26786ee950b5421f79fc089d65f1395aae65d335..21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -93,7 +93,8 @@ StatusOr> HloInstruction::CreateFromProto( [&computation_map](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid computation id(s)"; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Shape shape(proto.shape()); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); switch (opcode) { // Ops migrated to subclasses. 
@@ -101,23 +102,23 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 3) << "BatchNormTraining instruction should have 3 operands but sees " << proto.operand_ids_size(); - instruction = CreateBatchNormTraining( - proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), - proto.feature_index()); + instruction = + CreateBatchNormTraining(shape, operands(0), operands(1), operands(2), + proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: TF_RET_CHECK(proto.operand_ids_size() == 5) << "BatchNormInference instruction should have 5 operands but sees " << proto.operand_ids_size(); instruction = CreateBatchNormInference( - proto.shape(), operands(0), operands(1), operands(2), operands(3), + shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: TF_RET_CHECK(proto.operand_ids_size() == 5) << "BatchNormGrad instruction should have 5 operands but sees " << proto.operand_ids_size(); - instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + instruction = CreateBatchNormGrad(shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; @@ -127,7 +128,7 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); - instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + instruction = CreateFft(shape, operands(0), proto.fft_type(), absl::Span(fft_length)); break; } @@ -148,7 +149,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Recv instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + instruction = CreateRecv(shape.tuple_shapes(0), operands(0), proto.channel_id(), 
proto.is_host_transfer()); break; case HloOpcode::kRecvDone: @@ -161,7 +162,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Reverse instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateReverse(proto.shape(), operands(0), + instruction = CreateReverse(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -170,7 +171,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Concatenate instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = - CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); + CreateConcatenate(shape, all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) @@ -188,7 +189,7 @@ StatusOr> HloInstruction::CreateFromProto( absl::MakeSpan(reduce_operands) .subspan(reduce_operands.size() / 2, reduce_operands.size()); instruction = - CreateReduce(proto.shape(), inputs, init_values, + CreateReduce(shape, inputs, init_values, std::vector(proto.dimensions().begin(), proto.dimensions().end()), computations(0)); @@ -203,7 +204,7 @@ StatusOr> HloInstruction::CreateFromProto( auto sort_operands = all_operands(); HloInstruction* keys = sort_operands[0]; instruction = CreateSort( - proto.shape(), proto.dimensions(0), keys, + shape, proto.dimensions(0), keys, absl::Span(sort_operands).subspan(1)); break; } @@ -212,7 +213,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Transpose instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = - CreateTranspose(proto.shape(), operands(0), + CreateTranspose(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -221,7 +222,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Broadcast instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = - CreateBroadcast(proto.shape(), 
operands(0), + CreateBroadcast(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -229,7 +230,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Map instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateMap(proto.shape(), all_operands(), computations(0)); + instruction = CreateMap(shape, all_operands(), computations(0)); break; case HloOpcode::kSlice: { TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -242,8 +243,8 @@ StatusOr> HloInstruction::CreateFromProto( slice_limits.push_back(slice_dimensions.limit()); slice_strides.push_back(slice_dimensions.stride()); } - instruction = CreateSlice(proto.shape(), operands(0), slice_starts, - slice_limits, slice_strides); + instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits, + slice_strides); break; } case HloOpcode::kConstant: { @@ -253,7 +254,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = absl::make_unique(proto.shape()); + instruction = absl::make_unique(shape); } break; } @@ -284,55 +285,54 @@ StatusOr> HloInstruction::CreateFromProto( tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; - instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), - fused_computation); + instruction = + CreateFusion(shape, fusion_kind, all_operands(), fused_computation); break; } case HloOpcode::kRng: - instruction = - CreateRng(proto.shape(), proto.distribution(), all_operands()); + instruction = CreateRng(shape, proto.distribution(), all_operands()); break; case HloOpcode::kParameter: - instruction = CreateParameter(proto.parameter_number(), proto.shape(), - proto.name()); + instruction = + 
CreateParameter(proto.parameter_number(), shape, proto.name()); break; case HloOpcode::kGetTupleElement: TF_RET_CHECK(proto.operand_ids_size() == 1) << "GetTupleElement instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateGetTupleElement(proto.shape(), operands(0), - proto.tuple_index()); + instruction = + CreateGetTupleElement(shape, operands(0), proto.tuple_index()); break; case HloOpcode::kReducePrecision: TF_RET_CHECK(proto.operand_ids_size() == 1) << "ReducePrecision instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = - CreateReducePrecision(proto.shape(), operands(0), - proto.exponent_bits(), proto.mantissa_bits()); + instruction = CreateReducePrecision( + shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { - TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) && - (ShapeUtil::TupleElementCount(proto.shape()) == 2)) + TF_RET_CHECK(ShapeUtil::IsTuple(shape) && + (ShapeUtil::TupleElementCount(shape) == 2)) << "Infeed should have a tuple shape with 2 operands, but has: " - << proto.shape(); - const Shape& data_shape = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + << shape; + const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0); TF_RET_CHECK(proto.operand_ids_size() == 1) << "Infeed instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; - case HloOpcode::kOutfeed: + case HloOpcode::kOutfeed: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Outfeed instruction should have 2 operands but sees " << proto.operand_ids_size(); + Shape outfeed_shape(proto.outfeed_shape()); TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - operands(1), proto.outfeed_config()); + 
ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape)); + instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1), + proto.outfeed_config()); break; + } case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "CrossReplicaSum should have 1 called computation but sees " @@ -342,7 +342,7 @@ StatusOr> HloInstruction::CreateFromProto( all_reduce_id = proto.all_reduce_id(); } instruction = CreateCrossReplicaSum( - proto.shape(), all_operands(), computations(0), + shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), @@ -352,7 +352,7 @@ StatusOr> HloInstruction::CreateFromProto( } case HloOpcode::kAllToAll: { instruction = CreateAllToAll( - proto.shape(), all_operands(), + shape, all_operands(), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end())); @@ -368,8 +368,8 @@ StatusOr> HloInstruction::CreateFromProto( source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].second = proto.source_target_pairs(i).target(); } - instruction = CreateCollectivePermute(proto.shape(), operands(0), - source_target_pairs); + instruction = + CreateCollectivePermute(shape, operands(0), source_target_pairs); break; } case HloOpcode::kConvolution: { @@ -382,7 +382,7 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), + shape, operands(0), operands(1), std::max(proto.feature_group_count(), 1), proto.window(), proto.convolution_dimension_numbers(), precision_config); break; @@ -394,7 +394,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "ReduceWindow should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = 
CreateReduceWindow(proto.shape(), operands(0), operands(1), + instruction = CreateReduceWindow(shape, operands(0), operands(1), proto.window(), computations(0)); break; case HloOpcode::kSelectAndScatter: @@ -404,9 +404,9 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 2) << "SelectAndScatter should have 2 called computations but sees " << proto.called_computation_ids_size(); - instruction = CreateSelectAndScatter( - proto.shape(), operands(0), computations(0), proto.window(), - operands(1), operands(2), computations(1)); + instruction = CreateSelectAndScatter(shape, operands(0), computations(0), + proto.window(), operands(1), + operands(2), computations(1)); break; case HloOpcode::kCustomCall: if (proto.constrain_layout()) { @@ -414,16 +414,17 @@ StatusOr> HloInstruction::CreateFromProto( // vector of pointers essentially) so create a vector of shapes to pass // in. std::vector operand_shapes; - for (const Shape& shape : proto.operand_shapes_with_layout()) { - operand_shapes.push_back(shape); + for (const ShapeProto& shape_proto : + proto.operand_shapes_with_layout()) { + operand_shapes.emplace_back(shape_proto); } - instruction = CreateCustomCall( - proto.shape(), all_operands(), proto.custom_call_target(), - operand_shapes, proto.custom_call_opaque()); + instruction = + CreateCustomCall(shape, all_operands(), proto.custom_call_target(), + operand_shapes, proto.custom_call_opaque()); } else { - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target(), - proto.custom_call_opaque()); + instruction = + CreateCustomCall(shape, all_operands(), proto.custom_call_target(), + proto.custom_call_opaque()); } if (proto.has_window()) { static_cast(instruction.get()) @@ -443,8 +444,8 @@ StatusOr> HloInstruction::CreateFromProto( << "Pad instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_padding_config()); - instruction = CreatePad(proto.shape(), 
operands(0), operands(1), - proto.padding_config()); + instruction = + CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -452,8 +453,8 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); - instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), - slice_sizes); + instruction = + CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); break; } case HloOpcode::kGather: { @@ -469,7 +470,7 @@ StatusOr> HloInstruction::CreateFromProto( for (int64 bound : proto.gather_slice_sizes()) { gather_slice_sizes.push_back(bound); } - instruction = CreateGather(proto.shape(), operands(0), operands(1), + instruction = CreateGather(shape, operands(0), operands(1), *gather_dimension_numbers, gather_slice_sizes); break; } @@ -485,16 +486,15 @@ StatusOr> HloInstruction::CreateFromProto( auto scatter_dimension_numbers = absl::make_unique( proto.scatter_dimension_numbers()); - instruction = - CreateScatter(proto.shape(), operands(0), operands(1), operands(2), - computations(0), *scatter_dimension_numbers); + instruction = CreateScatter(shape, operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); break; } case HloOpcode::kIota: TF_RET_CHECK(proto.dimensions_size() == 1) << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); - instruction = CreateIota(proto.shape(), proto.dimensions(0)); + instruction = CreateIota(shape, proto.dimensions(0)); break; case HloOpcode::kDot: { TF_RET_CHECK(proto.has_dot_dimension_numbers()) @@ -506,8 +506,8 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = absl::make_unique( - proto.shape(), operands(0), 
operands(1), - proto.dot_dimension_numbers(), precision_config); + shape, operands(0), operands(1), proto.dot_dimension_numbers(), + precision_config); break; } case HloOpcode::kDomain: { @@ -529,7 +529,7 @@ StatusOr> HloInstruction::CreateFromProto( exit_hlo_sharding = std::make_shared(sharding); } instruction = absl::make_unique( - proto.shape(), operands(0), + shape, operands(0), absl::make_unique(entry_hlo_sharding), absl::make_unique(exit_hlo_sharding)); break; @@ -537,11 +537,11 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kGetDimensionSize: TF_RET_CHECK(proto.operand_ids_size() == 1); TF_RET_CHECK(proto.dimensions_size() == 1); - instruction = CreateGetDimensionSize(proto.shape(), operands(0), - proto.dimensions(0)); + instruction = + CreateGetDimensionSize(shape, operands(0), proto.dimensions(0)); break; default: { - instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (const int64 operand_id : proto.operand_ids()) { instruction->AppendOperand(instruction_map.at(operand_id)); } @@ -855,6 +855,16 @@ HloInstruction::CreateCollectivePermute( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } +/* static */ std::unique_ptr +HloInstruction::CreateAddDependency(HloInstruction* data_operand, + HloInstruction* token_operand) { + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kAddDependency, data_operand->shape())); + instruction->AppendOperand(data_operand); + instruction->AppendOperand(token_operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ -1394,6 +1404,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateAfterAll(new_operands); } break; + case HloOpcode::kAddDependency: + CHECK_EQ(new_operands.size(), 2); + clone = 
CreateAddDependency(new_operands[0], new_operands[1]); + break; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -1680,6 +1694,7 @@ bool HloInstruction::IdenticalSlowPath( // This opcode has complex or special behavior so just return false. case HloOpcode::kAfterAll: + case HloOpcode::kAddDependency: return false; // Remaining instructions with special values. @@ -1745,6 +1760,26 @@ bool HloInstruction::IdenticalSlowPath( return false; } +uint64 HloInstruction::Hash() const { + using tensorflow::Hash64Combine; + + uint64 hash_value = Hash64Combine(0, static_cast(opcode())); + hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape())); + + if (!IsCrossModuleAllReduce()) { + if (!operands().empty()) { + for (size_t i = 0; i < operands().size(); ++i) { + hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + } + } + } + + hash_value = Hash64Combine(hash_value, InnerHash()); + return hash_value; +} + +uint64 HloInstruction::InnerHash() const { return 13; } + void HloInstruction::RemoveUser(HloInstruction* user) { auto set_it = user_set_.find(user); CHECK(set_it != user_set_.end()); @@ -1900,6 +1935,11 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } +HloInstruction* HloInstruction::while_init() const { + CHECK_EQ(HloOpcode::kWhile, opcode_); + return operands_[0]; +} + HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -2214,7 +2254,7 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); - *proto.mutable_shape() = shape_; + *proto.mutable_shape() = shape_.ToProto(); for (const HloInstruction* operand : operands_) { proto.add_operand_ids(operand->unique_id()); } @@ -2462,6 +2502,8 @@ Status 
HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleDomain(this); case HloOpcode::kAfterAll: return visitor->HandleAfterAll(this); + case HloOpcode::kAddDependency: + return visitor->HandleAddDependency(this); case HloOpcode::kIota: return visitor->HandleIota(this); case HloOpcode::kGetDimensionSize: @@ -2623,36 +2665,6 @@ Status HloInstruction::AcceptWithOperandOrder( return Status::OK(); } -namespace { - -// Returns true if the given order is a topological sort of the instructions -// it contains. -bool OrderIsTopologicalSort(const std::vector& order) { - // Create a map from instruction to its position in 'order'. - std::unordered_map order_position; - for (int i = 0; i < order.size(); i++) { - if (!order_position.insert({order[i], i}).second) { - // Instruction order[i] is duplicated in the order. - return false; - } - } - // Verify that the operand of each instruction in the order is also in the - // order *and* the operand's position is earlier (defs are before uses for - // all ops). 
- for (auto* instruction : order) { - for (auto* operand : instruction->operands()) { - if (!ContainsKey(order_position, operand) || - order_position.at(operand) >= order_position.at(instruction)) { - return false; - } - } - } - - return true; -} - -} // namespace - Status HloInstruction::Accept( const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); @@ -3022,6 +3034,16 @@ const PrecisionConfig& HloInstruction::precision_config() const { LOG(FATAL) << "Unimplemented method."; } +PrecisionConfig* HloInstruction::mutable_precision_config() { + if (auto* convolution = DynCast(this)) { + return convolution->mutable_precision_config(); + } + if (auto* dot = DynCast(this)) { + return dot->mutable_precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3064,6 +3086,10 @@ int64 HloInstruction::concatenate_dimension() const { return Cast(this)->concatenate_dimension(); } +int64 HloInstruction::dimension() const { + return Cast(this)->dimension(); +} + bool HloInstruction::IsRank2Transpose() const { auto transpose = DynCast(this); return transpose != nullptr && transpose->IsRank2Transpose(); @@ -3243,6 +3269,11 @@ absl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } +void HloInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + return Cast(this)->set_all_reduce_id(all_reduce_id); +} + const ConvolutionDimensionNumbers& HloInstruction::convolution_dimension_numbers() const { if (auto convolution = DynCast(this)) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 818d4ede0f30f06d390daa70c508c6be6bbc38ce..a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -770,6 +770,9 @@ class HloInstruction { static 
std::unique_ptr CreateGetDimensionSize( const Shape& shape, HloInstruction* operand, int64 dimension); + static std::unique_ptr CreateAddDependency( + HloInstruction* data_operand, HloInstruction* token_operand); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -883,11 +886,15 @@ class HloInstruction { return false; } - // Use an explicit loop rather than ContainerEquals, because copying around - // std::functions may be too expensive in some cases. - for (size_t i = 0; i < operands().size(); ++i) { - if (!eq_operands(operand(i), other.operand(i))) { - return false; + // Two AllReduces are Identical if they have the same all_reduce_id. + // Their operands don't have to be Identical. + if (!IsCrossModuleAllReduce()) { + // Use an explicit loop rather than ContainerEquals, because copying + // around std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } } } @@ -898,6 +905,12 @@ class HloInstruction { return IdenticalSlowPath(other, eq_computations); } + // Generates a hash value of an HLO instruction. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO instructions, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const; + // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; @@ -997,6 +1010,8 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); + HloInstruction* while_init() const; + // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -1257,6 +1272,7 @@ class HloInstruction { // superior. // Precondition: opcode must be kConvolution or kDot. 
const PrecisionConfig& precision_config() const; + PrecisionConfig* mutable_precision_config(); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1317,6 +1333,9 @@ class HloInstruction { // Delegates to HloConcatenateInstruction::concatenate_dimension. int64 concatenate_dimension() const; + // Delegates to HloGetDimensionSizeInstruction::dimension. + int64 dimension() const; + // Returns whether this instruction does a rank-2 transposition. bool IsRank2Transpose() const; @@ -1435,6 +1454,7 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns data on the window in a windowed operation such as // convolution. @@ -1599,6 +1619,10 @@ class HloInstruction { const std::function& eq_computations) const; + // Generates a hash value specific to a particular type of an instruction. + // This function typically considers the inner root instruction. + virtual uint64 InnerHash() const; + // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 4c765aa375cd788612d144484df041dd6cd989f4..1ea02cf9c03866a598bec0e5356f0eb31ad27755 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -370,6 +370,11 @@ HloAllReduceInstruction::HloAllReduceInstruction( AppendComputation(reduce_computation); } +void HloAllReduceInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + all_reduce_id_ = all_reduce_id; +} + HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. 
@@ -1367,6 +1372,10 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +uint64 HloFusionInstruction::InnerHash() const { + return fused_instructions_computation()->Hash(); +} + std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { @@ -1610,7 +1619,7 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, HloInstructionProto HloOutfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_outfeed_config(outfeed_config()); - *proto.mutable_outfeed_shape() = outfeed_shape(); + *proto.mutable_outfeed_shape() = outfeed_shape().ToProto(); return proto; } @@ -1862,7 +1871,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { if (layout_constrained()) { proto.set_constrain_layout(true); for (const Shape& shape : operand_shapes_with_layout_) { - *proto.add_operand_shapes_with_layout() = shape; + *proto.add_operand_shapes_with_layout() = shape.ToProto(); } } return proto; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index d43a8973ccff697c27462b611446215df71973a5..b5c28137a145667a977d39c9d3c40c6d36a8436e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -252,6 +252,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { } absl::optional all_reduce_id() const { return all_reduce_id_; } + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -742,6 +743,8 @@ class HloFusionInstruction : public HloInstruction { const HloInstruction& other, const std::function& eq_computations) const override; + uint64 InnerHash() const override; + // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, @@ -954,6 +957,7 @@ class HloConvolutionInstruction : public HloInstruction { // information but it is presumed that the alternate lowering is strictly // superior. const PrecisionConfig& precision_config() const { return precision_config_; } + PrecisionConfig* mutable_precision_config() { return &precision_config_; } string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -1325,6 +1329,7 @@ class HloDotInstruction : public HloInstruction { // information but it is presumed that the alternate lowering is strictly // superior. const PrecisionConfig& precision_config() const { return precision_config_; } + PrecisionConfig* mutable_precision_config() { return &precision_config_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 3e2f8bcd52f9043f161197756a2060b28dded1d9..d6a2b292a3916b2ff85f278cf5cb9f1567df88fa 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 5269cad94d35be3dd1c009588bbe422ff1533364..d28e79d41ad5d58a8881cfb80d488684af26564f 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -237,8 +237,4 @@ void PrintTo(const HloInstruction* inst, ::std::ostream* os) { *os << (inst ? 
inst->ToString() : "nullptr"); } -void PrintTo(HloInstruction* inst, ::std::ostream* os) { - PrintTo(const_cast(inst), os); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 170ec93a334903cdc314f1950675ef30bc4cda5a..235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -385,7 +385,6 @@ std::vector Pointers(const Container& container) { // Tell GMock to print HloInstruction* by value, so error messages are nice. // Has to be in the same namespace as 'HloInstruction'. void PrintTo(const HloInstruction* inst, ::std::ostream* os); -void PrintTo(HloInstruction* inst, ::std::ostream* os); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 234fcd266aa09e193849ffb4526599114dfe22fe..d2740bcce26f04c5d7c8b64cfdaea53e3c697855 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -73,7 +73,7 @@ class ListScheduler { // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. static StatusOr Run( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -98,7 +98,7 @@ class ListScheduler { // comparison operators. using Priority = std::pair; - ListScheduler(const HloComputation& computation, + ListScheduler(HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -111,7 +111,7 @@ class ListScheduler { // instruction. 
An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( @@ -126,13 +126,13 @@ class ListScheduler { // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { unscheduled_use_count_[buffer] = 0; } } - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { ++unscheduled_use_count_[buffer]; } @@ -141,7 +141,7 @@ class ListScheduler { // Buffers live out of the computation have an implicit use at the end of // the computation. for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) + points_to_analysis.GetPointsToSet(computation->root_instruction()) .CreateFlattenedSet()) { ++unscheduled_use_count_[live_out_buffer]; } @@ -157,7 +157,7 @@ class ListScheduler { // HloInstruction, plus some cached metadata, saved for the purposes of making // BytesFreedIfScheduled fast. struct ReadyListEntry { - const HloInstruction* instruction; + HloInstruction* instruction; // The total size of all buffers defined by this instruction. int64 bytes_defined; @@ -171,7 +171,7 @@ class ListScheduler { }; // Creates a ReadyListEntry for the given instruction. 
- ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { + ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) { ReadyListEntry entry; entry.instruction = instruction; @@ -250,13 +250,13 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. absl::flat_hash_map unscheduled_pred_count; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { + for (HloInstruction* user : instruction->users()) { unscheduled_pred_count[user]++; } - for (const HloInstruction* succ : instruction->control_successors()) { + for (HloInstruction* succ : instruction->control_successors()) { unscheduled_pred_count[succ]++; } } @@ -275,7 +275,7 @@ class ListScheduler { ready_instructions[inst] = it; }; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { if (instruction->operands().empty() && instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); @@ -287,7 +287,7 @@ class ListScheduler { // schedule. 
auto best_it = ready_queue.end(); --best_it; - const HloInstruction* best = best_it->second.instruction; + HloInstruction* best = best_it->second.instruction; VLOG(2) << "Schedule instruction: " << best->ToShortString() << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); @@ -348,13 +348,13 @@ class ListScheduler { } } } - CHECK_EQ(schedule.size(), computation_.instruction_count()); - CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); + CHECK_EQ(schedule.size(), computation_->instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count()); return schedule; } - const HloComputation& computation_; + HloComputation* computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; // Computations are analyzed in post-order. When scheduling an instruction @@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes( } StatusOr ScheduleComputationHelper( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, const absl::flat_hash_map& memory_by_computation) { - VLOG(2) << "Computation: " << computation.name(); + VLOG(2) << "Computation: " << computation->name(); if (algorithm) { return algorithm(computation, points_to_analysis, size_function, memory_by_computation); @@ -404,17 +404,17 @@ StatusOr ScheduleComputationHelper( } // namespace StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. 
int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->instruction_count(); + int64 total_hlos = computation->parent()->instruction_count(); absl::flat_hash_map extra_users; absl::flat_hash_map total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; total_sizes[hlo] = 0; @@ -448,8 +448,8 @@ StatusOr DFSMemoryScheduler( total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } - CHECK_EQ(extra_users.size(), computation.instruction_count()); - CHECK_EQ(total_sizes.size(), computation.instruction_count()); + CHECK_EQ(extra_users.size(), computation->instruction_count()); + CHECK_EQ(total_sizes.size(), computation->instruction_count()); // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a @@ -459,7 +459,7 @@ StatusOr DFSMemoryScheduler( sequence.push_back(hlo); return Status::OK(); }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder( &visitor, [&extra_users, &total_sizes](const HloInstruction* a, const HloInstruction* b) { if (extra_users[a] != extra_users[b]) { @@ -470,12 +470,12 @@ StatusOr DFSMemoryScheduler( } return a->name() < b->name(); })); - CHECK_EQ(sequence.size(), computation.instruction_count()); + CHECK_EQ(sequence.size(), computation->instruction_count()); return sequence; } // namespace xla StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -485,16 +485,16 @@ StatusOr ListMemoryScheduler( } StatusOr PostOrderMemoryScheduler( - const HloComputation& 
computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { - return HloInstructionSequence(computation.MakeInstructionPostOrder()); + return HloInstructionSequence(computation->MakeInstructionPostOrder()); } StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -513,7 +513,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, HeapSimulator::MinimumMemoryForComputation( - computation, list_sequence, points_to_analysis, + *computation, list_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); @@ -522,7 +522,7 @@ StatusOr DefaultMemoryScheduler( size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, HeapSimulator::MinimumMemoryForComputation( - computation, dfs_sequence, points_to_analysis, + *computation, dfs_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); @@ -532,7 +532,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, HeapSimulator::MinimumMemoryForComputation( - computation, post_order_sequence, points_to_analysis, + *computation, post_order_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -555,17 +555,17 @@ StatusOr DefaultMemoryScheduler( } StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& 
size_function, const MemorySchedulerAlgorithm& algorithm) { - HloSchedule schedule(&module); + HloSchedule schedule(module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); + TuplePointsToAnalysis::Run(module)); absl::flat_hash_map memory_by_computation; - for (const auto* computation : module.MakeComputationPostOrder()) { + for (auto* computation : module->MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( - *computation, *points_to_analysis, size_function, + computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( @@ -583,11 +583,11 @@ StatusOr ScheduleModule( } StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function) { - CHECK(!computation.IsFusionComputation()); + CHECK(!computation->IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); + TuplePointsToAnalysis::Run(computation->parent())); absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); @@ -600,7 +600,7 @@ HloMemoryScheduler::HloMemoryScheduler( StatusOr HloMemoryScheduler::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, size_function_, algorithm_)); + ScheduleModule(module, size_function_, algorithm_)); TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index cca5dc493989811a0bb9790c3237e5468a3f2d67..7227bfb27c74758d2b79e404afc9eb97a1ca894d 100644 --- 
a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -36,14 +36,14 @@ namespace xla { // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. typedef std::function( - const HloComputation&, const TuplePointsToAnalysis&, + HloComputation*, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const absl::flat_hash_map&)> MemorySchedulerAlgorithm; // List scheduler StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -51,7 +51,7 @@ StatusOr ListMemoryScheduler( // DFS-order scheduler StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -59,7 +59,7 @@ StatusOr DFSMemoryScheduler( // Naive Post Order scheduler StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -69,7 +69,7 @@ StatusOr PostOrderMemoryScheduler( // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -79,13 +79,13 @@ StatusOr DefaultMemoryScheduler( // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. 
StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function); // A pass which schedules the HLO instructions in a module. The HloModule's diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 984a6266abb28f154a015e79645317e4e246fd0b..bc0d7e2bc00eab014f2660c95a51b966642eaee9 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -65,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloMemoryScheduler scheduler([](const BufferValue& buffer) { @@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK(module->schedule().Verify()); // Verify that all instructions are in the sequence. - const std::vector& sequence = + const std::vector& sequence = module->schedule().sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -124,9 +124,9 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
- const std::vector& sequence = + const std::vector& sequence = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -172,15 +172,16 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, abs_abs2)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -218,19 +219,19 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), 2); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -252,7 +253,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction::CreateParameter(0, r1f32, "cond_param")); HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + LiteralUtil::CreateR1({0, 0, 0, 0}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(module->entry_computation()->instruction_count(), diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 14bf17f4be16f8cf820753bc9f0473029834f1f8..fe8371384c0fa3900a9022f101ff0b296439cf16 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -240,8 +240,10 @@ HloModuleProto HloModule::ToProto() const { *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); } *proto.mutable_host_program_shape() = - entry_computation_layout().ComputeProgramShape(); + entry_computation_layout().ComputeProgramShape().ToProto(); *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); + *proto.mutable_dynamic_parameter_binding() = + dynamic_parameter_binding().ToProto(); return proto; } @@ -255,7 +257,7 @@ StatusOr> HloModule::CreateFromProto( // the entry parameters and root. 
TF_RET_CHECK(proto.has_host_program_shape()) << "No program shape found in the proto"; - const auto& expected_program_shape = proto.host_program_shape(); + ProgramShape expected_program_shape(proto.host_program_shape()); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { @@ -325,6 +327,10 @@ StatusOr> HloModule::CreateFromProto( // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. + TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_, + DynamicParameterBinding::CreateFromProto( + proto.dynamic_parameter_binding())); + absl::flat_hash_set computation_names; absl::flat_hash_set instruction_names; absl::flat_hash_set computation_ids; @@ -363,9 +369,9 @@ StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options) { TF_RET_CHECK(module.has_host_program_shape()) << "No program shape found in the proto"; - const auto& program_shape = module.host_program_shape(); + ProgramShape program_shape(module.host_program_shape()); - HloModuleConfig module_config(program_shape); + HloModuleConfig module_config(ProgramShape{program_shape}); module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 8a1f999e3ab076b87a651a915f4de93320e7067f..7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -103,11 +104,7 @@ class HloModule { HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module. - const HloComputation* entry_computation() const { - CHECK_NE(nullptr, entry_computation_); - return entry_computation_; - } - HloComputation* entry_computation() { + HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } @@ -135,6 +132,12 @@ class HloModule { return config_.entry_computation_layout(); } + // Generates a hash value of an HLO module. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO modules, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const { return entry_computation()->Hash(); } + // Gets the computations in this module. // // Returns a view of HloComputation*s, so you can iterate over this in the @@ -232,6 +235,16 @@ class HloModule { return input_output_alias_config_; } + // DynamicParameterBinding holds the list of bindings that indicates which + // parameter dimensions are dynamic and which parameters represent their + // runtime value. + DynamicParameterBinding& dynamic_parameter_binding() { + return dynamic_parameter_binding_; + } + const DynamicParameterBinding& dynamic_parameter_binding() const { + return dynamic_parameter_binding_; + } + // Returns an id that is unique to this module across all modules created over // the lifetime of this process. 
int unique_id() const { return unique_id_; } @@ -285,6 +298,9 @@ class HloModule { // alias_config indicates the alias information of input/output buffers that // are expected from the module. HloInputOutputAliasConfig input_output_alias_config_; + + // Bindings for dynamic parameter mapping. + DynamicParameterBinding dynamic_parameter_binding_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 3ae67e4e5ee90ca182c7c3d97a67d070431ce851..620cb7e01ad1a060915f5b73474f6950ab18122a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -63,7 +63,7 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); EXPECT_THAT(module->MakeComputationPostOrder(), @@ -72,7 +72,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); @@ -88,7 +88,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { TEST_F(HloModuleTest, CloneTest) { // Create and copy a module with a diamond call graph of computations. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -111,7 +111,7 @@ TEST_F(HloModuleTest, CloneTest) { } TEST_F(HloModuleTest, CloneHasFusion) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); // Create the fused computation. HloComputation* fused_computation; @@ -154,7 +154,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -174,7 +174,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { TEST_F(HloModuleTest, LargeConstantToString) { // Create a module with a single computation. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( @@ -194,8 +194,8 @@ TEST_F(HloModuleTest, LargeConstantToString) { } TEST_F(HloModuleTest, UniqueModuleId) { - auto module_a = CreateNewUnverifiedModule(); - auto module_b = CreateNewUnverifiedModule(); + auto module_a = CreateNewVerifiedModule(); + auto module_b = CreateNewVerifiedModule(); EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 70c7d70b41c5c7bc94d1fac83c0fcf71f155b5f0..127cfd165a5d8229cac3035f56a66f1bcfa734f3 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,8 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAddDependency, "add-dependency") \ + V(kAfterAll, "after-all", 
kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ @@ -84,7 +86,6 @@ namespace xla { V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetDimensionSize, "get-dimension-size") \ - V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kImag, "imag") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index f5f99bece18cc637365118ddcd1273da05f4e1b6..ca6a154809be46d6a0305c29e2b89219de408019 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. TF_DCHECK_OK(schedule_.Verify()); for (const auto& computation_sequence : schedule_.sequences()) { - const std::vector& order = - computation_sequence.second.instructions(); + const auto& order = computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { InsertOrDie(&order_position_, order[i], i); } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 2ab8aa57f6ed4586c3376ee7c44126c0ed19ea0b..3ca77e60cd5275c22eb0e338cd5437fc44b49958 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -53,7 +53,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // %c = Constant(42.0f) // // This results in a diamond-shaped callgraph. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder_c = HloComputation::Builder("C"); @@ -126,7 +126,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { // %constant = Constant(1.0) // return While(%constant, body, condition) // - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -176,7 +176,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { // Entry parameter should always be defined before other instruction. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -209,7 +209,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // %while = While(%constant, body, condition) // %add = Add(%constant, %while) // - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -407,7 +407,7 @@ TEST_F(HloOrderingTest, // %dead = Constant(123.0) // // %root should interfere with %dead. - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -455,7 +455,7 @@ TEST_F(HloOrderingTest, // ROOT %call = call({%c}), subcomputation // // %root should interfere with %dead. 
- auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto subbuilder = HloComputation::Builder(TestName() + ".sub"); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 4390145c6bd7484987b2851ef92336defffb388b..9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -47,11 +47,11 @@ const double kF16max = 65504; // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. -HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { +HloSchedule ScheduleFromInstructionOrder(HloModule* module) { HloSchedule schedule(module); - for (const HloComputation* computation : module->computations()) { + for (HloComputation* computation : module->computations()) { if (!computation->IsFusionComputation()) { - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { schedule.GetOrCreateSequence(computation).push_back(instruction); } } @@ -850,6 +850,15 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } break; } + case HloOpcode::kAddDependency: { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateAddDependency(operands[0], operands[1])); + break; + } case HloOpcode::kSort: { optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index c59bdc0a0b372d829ee61f0a048b7704498e0d0e..ab71f011ac9d77d00ddfb41aca7a224d26d416b7 100644 --- 
a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -21,7 +21,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,7 +30,7 @@ limitations under the License. namespace xla { namespace { -namespace op = ::xla::testing::opcode_matchers; +namespace m = ::xla::match; using absl::string_view; struct TestData { @@ -195,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) } )" @@ -587,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} 
%constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -1241,7 +1242,38 @@ ENTRY Sort { } )" + }, +// AfterAll with multiple operands +{ +"AfterAllWithMultipleOperands", +R"(HloModule AfterAllWithMultipleOperands + +ENTRY AfterAllWithMultipleOperands { + p0 = f32[] parameter(0) + token0 = token[] after-all() + token1 = token[] after-all() + ROOT after-all = token[] after-all(p0, token0, token1) } + +)" +}, +// AddDependency +// A dependency chain is created from 'neg' to 'exp' using tokens. +{ +"AddDependency", +R"(HloModule AddDependency + +ENTRY AddDependency { + p = f32[] parameter(0) + neg = f32[] negate(p) + token = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token) + exp = f32[] exponential(p_after_token) + ROOT sum = f32[] add(neg, exp) +} + +)" +}, }); // clang-format on } @@ -1862,7 +1894,8 @@ ENTRY ReduceR3ToR2 { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); ASSERT_NE(module->entry_computation(), nullptr); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Reduce())); } TEST_F(HloParserTest, ParseSharding) { @@ -1922,7 +1955,7 @@ TEST(HloParserSingleOpTest, SingleOp) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); } TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { @@ -1950,7 +1983,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); } TEST(HloParserSingleOpTest, CanonicalOp) { @@ -1959,7 +1992,7 @@ 
TEST(HloParserSingleOpTest, CanonicalOp) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); EXPECT_EQ( computation->root_instruction()->ToString(HloPrintOptions::Canonical()), text); @@ -2013,7 +2046,11 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Fusion(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Op() + .WithOpcode(HloOpcode::kFusion) + .WithNumOperands(2) + .WithOperand(0, m::Parameter(0)) + .WithOperand(1, m::Parameter(1)))); } TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { @@ -2057,7 +2094,7 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Convolution(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1)))); auto* convolution = Cast(computation->root_instruction()); EXPECT_EQ(convolution->feature_group_count(), 1); @@ -2121,8 +2158,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( module->schedule().sequence(module->entry_computation()).instructions(), - ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), - op::Multiply(), op::Parameter(), op::Add())); + ::testing::ElementsAre( + GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()), + GmockMatch(m::Parameter()), GmockMatch(m::Multiply()), + GmockMatch(m::Parameter()), GmockMatch(m::Add()))); } TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { @@ -2148,8 +2187,10 @@ ENTRY 
%axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( module->schedule().sequence(module->entry_computation()).instructions(), - ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), - op::Broadcast(), op::Multiply(), op::Add())); + ::testing::ElementsAre( + GmockMatch(m::Parameter()), GmockMatch(m::Parameter()), + GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()), + GmockMatch(m::Multiply()), GmockMatch(m::Add()))); } TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 312b5d020c398feb7738d14a9cfa0928d5178948..51177f24f5ee702be96fc8b4530ed38a5798109f 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -113,7 +113,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, } const string message = - StrCat("after ", after_pass_name, ", before ", before_pass_name); + absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; VLOG(3) << module.entry_computation_layout().ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index cf33668f5bfa64a7843efc76e9f6768d18533240..981d06ce101644ecce587c4bd2f7a12c8edf6548 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -48,7 +48,7 @@ StatusOr> CreateModuleFromProto( return std::move(module); } -StatusOr> EntryComputationParameterShapes( +StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); @@ -57,15 +57,16 @@ StatusOr> 
EntryComputationParameterShapes( return NotFound("HloProto missing program shape."); } - std::vector parameter_shapes; + std::vector parameter_shapes; const auto& program_shape = hlo_proto.hlo_module().host_program_shape(); - for (const Shape& shape : program_shape.parameters()) { + for (const ShapeProto& shape : program_shape.parameters()) { parameter_shapes.push_back(&shape); } return parameter_shapes; } -StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { +StatusOr EntryComputationOutputShape( + const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 1db82dd6fcaa5d7fe7d65894c1021105f0b26266..31ea2aaffd9cdb76d21edbd0d4a03aa5f865f4f0 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -43,12 +43,13 @@ StatusOr> CreateModuleFromProto( // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. -StatusOr> EntryComputationParameterShapes( +StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto); // Returns the shape of the output of the entry computation. The shape pointer // refers to the output shape inside of the given HloProto. -StatusOr EntryComputationOutputShape(const HloProto& hlo_proto); +StatusOr EntryComputationOutputShape( + const HloProto& hlo_proto); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 49e46ecd00ee4370f3e93746348373b79febed3d..48add75523f02005c70bc6baf69a6b7d5aa4f7ef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector; // before arbitrary elements. 
class InstructionList { public: - explicit InstructionList(const std::vector& order) { + explicit InstructionList(const HloInstructionSequence& order) { int64 position = 0; Item* last = nullptr; - for (const HloInstruction* inst : order) { + for (HloInstruction* inst : order.instructions()) { // Add a new item to the linked list. Item* item = new Item; item->next = nullptr; @@ -151,7 +151,7 @@ class InstructionList { // to be monotonically increasing through the list, and so is still useful // for quickly(-ish) determining the order of arbitrary instructions in // the list. - item->instruction = const_cast(inst); + item->instruction = inst; item->position = position; position++; @@ -927,7 +927,7 @@ Item* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, - const std::vector& order) const { + const HloInstructionSequence& order) const { InstructionList instruction_list(order); MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, instruction_list); @@ -971,8 +971,7 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list( - schedule->sequence(computation).instructions()); + InstructionList instruction_list(schedule->sequence(computation)); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1184,7 +1183,7 @@ StatusOr HloRematerialization::RematerializeComputation( sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { - const HloInstruction* instruction = item->instruction; + HloInstruction* instruction = item->instruction; sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1235,10 +1234,8 @@ StatusOr 
HloRematerialization::Run(HloModule* module) { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - module->schedule() - .sequence(node.computation()) - .instructions())); + ComputePeakMemory(node.computation(), module->schedule().sequence( + node.computation()))); } return Status::OK(); }, diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 70d83c04f07ca7fd0139f586869e8fe688f958f4..a07d348041b72bba45c6fd1f726f2a0065d01e53 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass { // peak memory is the maximum total size of all live HLO instruction values at // any program point. 'order' is the order in which the HLO instructions will // be emitted which is used to determine lifespans of HLO values. - StatusOr ComputePeakMemory( - const HloComputation* computation, - const std::vector& order) const; + StatusOr ComputePeakMemory(const HloComputation* computation, + const HloInstructionSequence& order) const; // Returns the peak memory usage of the called computations for the given // instruction. Zero is returned if the instruction calls no computations. diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 3f0ca342b4c84216ddd5ee553848360d8bd1ff0b..5a9b820a9d7f58695383b21c9e2126cf98970c83 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -205,6 +205,40 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + // Get service run options. 
+ se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return ExecuteWithDeviceBuffers( + /*executable=*/std::move(executable), + /*arguments=*/argument_pointers, + /*profile=*/profile); +} + StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 2e934bf66ae43ea412f242030b874dddb6d3722d..bb792cf8c9825ff67ca33bbcf2c3c32b1a0ecb85 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -136,6 +136,21 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + // Creates an executable object given an HLO module. If run_hlo_passes is + // true, the HLO passes will be run as part of compilation. 
+ StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes); + // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. @@ -152,11 +167,6 @@ class HloRunner { const Backend& backend() const; private: - // Creates an executable object given an HLO module. If run_hlo_passes is - // true, the HLO passes will be run before. - StatusOr> CreateExecutable( - std::unique_ptr module, bool run_hlo_passes); - // Creates a ServiceExecutableRunOptions object to configure a run on device, // using the provided stream object. If device_assignment is not nullptr, it // will be used to configure the replication parameters. Replicated executions diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index a5780b7551a43f2b64f2ac61ef1bf6ce9e07eb16..8f6eb974c5179b420c8f961393ca923e0a3b3530 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -46,8 +46,8 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -81,9 +81,8 @@ StatusOr HloSchedule::ToProto() const { return std::move(proto); } -void HloSchedule::set_sequence( - const HloComputation* computation, - absl::Span sequence) { +void HloSchedule::set_sequence(const HloComputation* computation, + absl::Span sequence) { set_sequence(computation, HloInstructionSequence(sequence)); } @@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer 
for instructions in the // computation. - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } @@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. - absl::flat_hash_map> + absl::flat_hash_map> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's @@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule( // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; + std::queue worklist; - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { if (ids_in_schedule.count(instruction->unique_id()) == 0) { // This is a newly added instruction which is not in the schedule. if (instruction->operands().empty()) { @@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule( // Lambda which schedules all instructions on the worklist. auto schedule_worklist = [&]() { while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); + HloInstruction* instruction = worklist.front(); worklist.pop(); new_sequence.push_back(instruction); - std::vector* new_users = + std::vector* new_users = tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); if (new_users != nullptr) { // This just-scheduled instruction has users which are newly added to // the module. 
Update the number of unscheduled operands and push the // newly added instruction to the worklist if it is ready to // schedule. - for (const HloInstruction* new_user : *new_users) { + for (HloInstruction* new_user : *new_users) { unscheduled_operand_count.at(new_user)--; CHECK_GE(unscheduled_operand_count.at(new_user), 0); if (unscheduled_operand_count.at(new_user) == 0) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 0a714101ee587aa847fa674bbde5586287c51f33..486ddbf499de80c634bc497158cd79ca066cc866 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -35,14 +35,14 @@ class HloInstructionSequence { public: HloInstructionSequence() = default; explicit HloInstructionSequence( - absl::Span instructions) { - for (const HloInstruction* instruction : instructions) { + absl::Span instructions) { + for (HloInstruction* instruction : instructions) { push_back(instruction); } } // Adds the instruction to the end of the sequence. - void push_back(const HloInstruction* instruction) { + void push_back(HloInstruction* instruction) { instruction_sequence_.push_back(instruction); id_sequence_.push_back(instruction->unique_id()); } @@ -56,7 +56,7 @@ class HloInstructionSequence { int64 size() const { return instruction_sequence_.size(); } // Returns the sequence of HLO instructions. - const std::vector& instructions() const { + const std::vector& instructions() const { return instruction_sequence_; } @@ -65,7 +65,7 @@ class HloInstructionSequence { private: // The sequence as HloInstructions. - std::vector instruction_sequence_; + std::vector instruction_sequence_; // The sequence of HLO instructions, represented by their unique IDs. The // sequence is stored as both HloInstructions and unique IDs because the @@ -98,7 +98,7 @@ class HloSchedule { // Sets the sequence for the given computation to the given sequence. 
void set_sequence(const HloComputation* computation, - absl::Span sequence); + absl::Span sequence); void set_sequence(const HloComputation* computation, HloInstructionSequence sequence); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 1424569ac1f62e4b965876141f1eb40be4f15bea..0e56e6f760e35ddcb45c6f58771d78405a09acfe 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -56,10 +56,10 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); - const std::vector& entry_schedule = + const auto& entry_schedule = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(entry_schedule.size(), 6); @@ -90,7 +90,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -139,7 +139,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -183,7 +183,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { 
+ ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 88329c899794a6e0f5102d181d6161fe17f89932..f5061304456e04ab40448861343ef201c9450dcf 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(const_cast(user)) > 0) { + domain.exit_domains.count(user) > 0) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. 
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 11994d99c93e9d51691e482a3e3233b06fb0d060..c1073911ea9dc3811c195e27bcbae9b00929ad17 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -66,7 +66,7 @@ class HloSubcomputationUnificationTest : public HloTestBase { }; TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -103,7 +103,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { } TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -184,7 +184,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { // Regression test for b/31466798. Checks that entry_computation is still valid // after unification. TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); for (int i = 0; i < 2; ++i) { HloComputation::Builder builder("pow"); auto x = diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index b6670d409b92e8be42f5cdb40fba8d662ae83958..1f01b0bb365450a933da9cc443db5223c06903f0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -166,9 +166,6 @@ class HloValue : public BufferValue { // Whether this value is live out of the HLO module. bool live_out_of_module_ = false; - - // Whether this value is live out of its computation. 
- bool live_out_of_computation_ = false; }; std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 27fd685a69a0bbd95b1d8d266ce6177a6c557f55..77db7b098a38ff4efdcc7447935fae61561c9ff4 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -753,13 +753,19 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { for (const HloInstruction* operand : token->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); + return CheckShape(token, ShapeUtil::MakeTokenShape()); +} + +Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) { + TF_RETURN_IF_ERROR(CheckOperandCount(add_dependency, 2)); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); + return CheckShape(add_dependency, add_dependency->operand(0)->shape()); } Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { - return CheckShape( - get_size, ShapeInference::InferGetDimensionSizeShape( - get_size->operand(0)->shape(), get_size->dimensions(0))); + return CheckShape(get_size, + ShapeInference::InferGetDimensionSizeShape( + get_size->operand(0)->shape(), get_size->dimension())); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -1373,9 +1379,8 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { const Layout& operand_layout = operand_shape.layout(); TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) << "Instruction shouldn't change layouts " - << instruction->ToString() << " From " - << ShapeUtil::HumanString(result_shape) << " To " - << ShapeUtil::HumanString(operand_shape); + << instruction->ToString() << " From " << result_shape << " To " + << operand_shape; } } } @@ -1426,6 +1431,8 @@ StatusOr HloVerifier::Run(HloModule* module) { 
return target_metadata_->ShapeSize(shape); })); + TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 9fbfd6a21c1f1148801000169046fbcbb37934fe..e4d0c3d6957885f1d719fedb5a900de601e397f8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -95,6 +95,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; Status HandleGetDimensionSize(HloInstruction* get_size) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 5ddfe0a944f04f070f9bdb81697425ee417ac15a..4bc557e4e62e7df4e25fda86fe417e84129b464c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -35,6 +35,10 @@ namespace { using ::testing::HasSubstr; +std::unique_ptr CreateUnverifiedModule() { + return absl::make_unique("module", HloModuleConfig()); +} + // This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. 
class HloVerifierTest : public HloTestBase { @@ -66,7 +70,7 @@ TEST_F(HloVerifierTest, NullInstructionParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -85,7 +89,7 @@ TEST_F(HloVerifierTest, NullComputationParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -104,7 +108,7 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloComputation::Builder emb_builder(TestName()); @@ -138,7 +142,7 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { builder.AddInstruction( HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Run the verifier twice. 
It should fail both times, because it shouldn't @@ -303,7 +307,7 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto status = verifier().Run(module.get()).status(); @@ -327,7 +331,7 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 20cc18f981574adf1d95c9f1f87c95634238db06..98246d5403e4aebc2f4d81e52145706355ddd9a9 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -481,8 +481,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { - { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, - { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } } }) (reshape %indices to s32[]) 0->[]) )"; @@ -512,8 +512,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,6] s32[2,1,1,6] { - { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, - { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + { /*i0=0*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } }, + { /*i0=1*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } } }) (reshape %indices to s32[5]) 0->[2]) )"; diff --git 
a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 7f2d7e7cffc6debaaf9b64fffc5a8a7037ecdaa3..7559ed1bab84b21a4d51bc38db999900befcfad7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -103,7 +103,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: - case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: case HloOpcode::kTupleSelect: @@ -116,7 +115,10 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSin: return ShapeUtil::ElementIsComplex(instruction.shape()); - // Expensive instructions. + // Expensive instructions or unusual instructions for which fusion is + // nonsensical. + case HloOpcode::kAddDependency: + case HloOpcode::kAfterAll: case HloOpcode::kAtan2: case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: @@ -455,8 +457,13 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = HloReachabilityMap::Build(computation_); - HloInstructionSet do_not_duplicate = - ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + HloInstructionSet do_not_duplicate; + // If we allow duplications, we need to compute which instructions we do not + // want to duplicate based on a global analysis of the graph. 
+ if (may_duplicate_) { + do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + } auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph @@ -564,8 +571,8 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { - // A consumer operand may have been multii-output fused into a parallel - // consumer and thus be missing from the oridinal reachability map. + // A consumer operand may have been multi-output fused into a parallel + // consumer and thus be missing from the original reachability map. if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { reachability_ = HloReachabilityMap::Build(consumer->parent()); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 39904bd54b09a916d3e26e90c62cd6a202f9588d..58b7135cea7419f13d60ed510ecf7a88126aee48 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -133,7 +133,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewUnverifiedModule(); + auto module = 
CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -149,7 +149,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( @@ -394,6 +394,56 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. 
+ EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + +TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { auto module = ParseHloString(R"( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index a06d6113e84630df14ff68280c248cccb9afaf06..de9204011ce5ba8a9fc2871c6bd7120b6ed371b5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -37,7 +37,7 @@ namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, /*hlo_profile_index_map=*/nullptr), @@ -85,6 +85,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); + evaluator_->ResetVisitStates(); TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( *computation, arg_literals)); } diff --git 
a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 3b1ebce0c75457d65e6834c809fe488a9c4a159a..bda13d376360306c81230e41b01cefc6caff230d 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -42,7 +42,7 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module, + InterpreterExecutable(std::unique_ptr hlo_module, std::unique_ptr evaluator); ~InterpreterExecutable() override; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 4fb67bd0b72fc591c1ffa76ebb0513bf14ed3737..e3e5fa71543baa309b3a68888b1b9bdfd43cfbd5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -78,9 +78,14 @@ port::Status XlaInterpreterExecutor::SynchronousMemcpy( return port::Status::OK(); } -bool XlaInterpreterExecutor::HostCallback(Stream *stream, - std::function callback) { - AsExecutorStream(stream)->EnqueueTask(callback); +bool XlaInterpreterExecutor::HostCallback( + Stream *stream, std::function callback) { + AsExecutorStream(stream)->EnqueueTask([callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index fbb99457847dca69a1901006d5d8ff713882f918..400c30515464ed5b00251fba303fef303a26b97b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -125,7 +125,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return 
port::Status{port::error::UNIMPLEMENTED, ""}; } - bool HostCallback(Stream *stream, std::function callback) override; + bool HostCallback(Stream *stream, + std::function callback) override; port::Status AllocateEvent(Event *event) override { return port::Status{port::error::UNIMPLEMENTED, ""}; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a90411922205c0006159ff99f35a70138b1bee4f..eddef850cf5250b85b564c1e6c92d1cc8ecd1a43 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2000,6 +2000,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( switch (instruction->opcode()) { case HloOpcode::kAbs: case HloOpcode::kAdd: + case HloOpcode::kAddDependency: case HloOpcode::kAnd: case HloOpcode::kAtan2: case HloOpcode::kBitcastConvert: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 2400b7bb7c409a4dcb33e6e8f4b409738510f3d6..5c661bfacb08fe27f3cbdc1fb9db083315166008 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -27,10 +27,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -42,11 +43,10 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = xla::match; using ::testing::ElementsAre; class LayoutAssignmentTest : public HloTestBase { @@ -328,11 +328,10 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // %tuple.1 = Tuple(%copy) layout=({0,1}) // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1})) // - EXPECT_TRUE( - AlgebraicSimplifier(/*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return false; }) - .Run(m.get()) - .ValueOrDie()); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_is_layout_sensitive(true); + EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie()); HloInstruction* root = m->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); @@ -343,7 +342,8 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // Verify the structure of the HLO graph. 
EXPECT_THAT(root, - op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant)))); + GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)), + m::Tuple(m::Copy(m::Op().Is(constant)))))); } TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { @@ -947,9 +947,11 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); - EXPECT_THAT(root, op::Add(op::Parameter(), - op::Slice(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy))))); + EXPECT_THAT( + root, + GmockMatch(m::Add( + m::Parameter(), + m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy))))); } TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { @@ -977,10 +979,11 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); EXPECT_THAT(root, - op::Add(op::Parameter(), - op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy)), - op::Parameter(2)))); + GmockMatch(m::Add( + m::Parameter(), + m::DynamicSlice( + m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), + m::Parameter(2))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { @@ -1008,11 +1011,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); - EXPECT_THAT(root, - op::Add(op::Parameter(), - op::Concatenate(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy)), - op::Parameter(2)))); + EXPECT_THAT( + root, + GmockMatch(m::Add( + m::Parameter(), + m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), + 
m::Parameter(2))))); } TEST_F(LayoutAssignmentTest, @@ -1039,7 +1043,8 @@ TEST_F(LayoutAssignmentTest, .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); + EXPECT_THAT(root, + GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1)))); } TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { @@ -1063,8 +1068,9 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); - EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), - op::ShapeWithLayout(shape_copy)))); + EXPECT_THAT(root, + GmockMatch(m::Slice( + m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy)))); } TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { @@ -1150,7 +1156,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { AssignLayouts(m.get(), &computation_layout); HloInstruction* root = m->entry_computation()->root_instruction(); - ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter()))); ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); } @@ -1166,7 +1172,7 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { AssignLayouts(m.get(), &computation_layout); HloInstruction* root = m->entry_computation()->root_instruction(); - ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter()))); ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); } @@ -1197,7 +1203,7 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3 // The custom call should be partially encapsulated in kCopy instructions // because of 
the layout mismatches. ASSERT_THAT(m->entry_computation()->root_instruction(), - op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); + GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter())))); const HloInstruction* custom_call = m->entry_computation()->root_instruction()->operand(0); @@ -1223,7 +1229,7 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { AssignLayouts(m.get(), &computation_layout); ASSERT_THAT(m->entry_computation()->root_instruction(), - op::Copy(op::CustomCall())); + GmockMatch(m::Copy(m::CustomCall()))); const HloInstruction* custom_call = m->entry_computation()->root_instruction()->operand(0); @@ -1257,7 +1263,7 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); ASSERT_THAT(m->entry_computation()->root_instruction(), - op::Copy(op::CustomCall(op::Tuple()))); + GmockMatch(m::Copy(m::CustomCall(m::Tuple())))); const HloInstruction* custom_call = m->entry_computation()->root_instruction()->operand(0); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index f4b05f29c38529b3cce81b4c8ee6fae5c00cafcc..d6d84994ee147f4b8c1a333b0eaccdf6e0a2219b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -108,6 +109,14 @@ class IrArray { Index(absl::Span multidim, llvm::Value* linear, const Shape& shape); + // Returns an index that adds `addend` to the given `dim` of the object. 
+ Index AddOffsetToDim(llvm::Value* addend, int64 dim, + llvm::IRBuilder<>* b) const { + IrArray::Index index = *this; + index[dim] = b->CreateAdd(index[dim], addend); + return index; + } + const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index e5fbdbd51b8a9aa14decadedd1eeb3bdbf831738..c26711e526c9b89cdedcb6aed9f93d41dd25dc83 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -52,6 +52,29 @@ Shape MergeDimensions(absl::Span segs, const Shape& shape) { return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), dimensions); } + +// Given an index for a shape, return the equivalent new index if the shape is +// reshaped to another shape. +IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape, + const Shape& reshaped_shape, + llvm::IRBuilder<>* b) { + auto bounds = shape.dimensions(); + auto minor_to_major = shape.layout().minor_to_major(); + llvm::Value* linear_index = index.GetConstantWithIndexType(0); + int64 multiplier = 1; + for (int i = 0; i < index.size(); ++i) { + int64 dim = minor_to_major[i]; + llvm::Value* addend = b->CreateMul( + index[dim], index.GetConstantWithIndexType(multiplier), "linearizing", + /*HasNUW=*/true, /*HasNSW=*/true); + linear_index = b->CreateAdd(linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= bounds[dim]; + } + + return IrArray::Index(linear_index, reshaped_shape, b); +} + } // namespace absl::optional > FindTranspose021(const Shape& a, @@ -60,28 +83,30 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } - std::vector perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto 
layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + std::vector permutation(a.dimensions().size()); + absl::Span minor_to_major_a = LayoutUtil::MinorToMajor(a); + std::vector major_to_minor_a(minor_to_major_a.rbegin(), + minor_to_major_a.rend()); + absl::Span minor_to_major_b = LayoutUtil::MinorToMajor(b); + std::vector major_to_minor_b(minor_to_major_b.rbegin(), + minor_to_major_b.rend()); + for (size_t i = 0; i < permutation.size(); ++i) { + permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]); } - auto segs = ConsecutiveSegments(perm); - if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) { - Shape norm_a = + + std::vector segments = ConsecutiveSegments(permutation); + if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) { + Shape descending_layout_shape = ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape reduced_a = MergeDimensions(segs, norm_a); - auto reduced_a_dims = reduced_a.dimensions(); + Shape normalized_shape = MergeDimensions(segments, descending_layout_shape); + absl::Span normalized_dims = + AsInt64Slice(normalized_shape.dimensions()); std::vector dims_021; - if (2 == segs.size()) { + if (2 == segments.size()) { // The logical component-0 is of size one. 
- dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]}; + dims_021 = {1, normalized_dims[1], normalized_dims[0]}; } else { - dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]}; + dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]}; } return dims_021; @@ -90,27 +115,117 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b) { - auto bounds = reduced_output_shape.dimensions(); - auto minor_to_major = reduced_output_shape.layout().minor_to_major(); - llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0); - int64 multiplier = 1; - for (int i = 0; i < reduced_output_index.size(); ++i) { - int64 dim = minor_to_major[i]; - llvm::Value* addend = - b->CreateMul(reduced_output_index[dim], - reduced_output_index.GetConstantWithIndexType(multiplier), - "linearizing", - /*HasNUW=*/true, /*HasNSW=*/true); - linear_index = b->CreateAdd(linear_index, addend, "", - /*HasNUW=*/true, /*HasNSW=*/true); - multiplier *= bounds[dim]; +KernelMappingScheme::KernelMappingScheme( + absl::Span dims_in_elems, int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, int64 num_threads_y, + int64 num_threads_x, llvm::IRBuilder<>* b) + : b_(b), + dims_in_elems_(dims_in_elems), + tile_sizes_{1, tile_size_y, tile_size_x}, + num_threads_x_(num_threads_x), + num_threads_y_(num_threads_y) { + DCHECK_EQ(dims_in_elems_.size(), 3); + DCHECK_EQ(req_block_sizes.size(), 3); + + DCHECK_EQ(tile_size_y % num_threads_y_, 0); + DCHECK_EQ(tile_size_x % num_threads_x_, 0); + + dims_in_tiles_ = ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_); + block_sizes_.reserve(req_block_sizes.size()); + absl::c_transform(req_block_sizes, dims_in_tiles_, + std::back_inserter(block_sizes_), + [](const int64 requested_size, const int64 max_size) { + 
return std::min(requested_size, max_size); + }); + dims_in_blocks_ = ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_); + + VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]"; + VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]"; + VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",") + << "]"; +} + +IrArray::Index KernelMappingScheme::GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape) { + DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size()); + Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + unnormalized_shape.element_type(), GetDimensionsInElements()); + return GetReshapedIndex(normalized_shape_index, output_shape, + unnormalized_shape, b_); +} + +IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(), + llvm::cast(block_id)); + llvm::Value* linear_block_id = + b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); + return IrArray::Index(linear_block_id, + ShapeUtil::MakeShapeWithDescendingLayout( + PRED /*arbitrary*/, dims_in_blocks_), + b_); +} + +IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( + const IrArray::Index& block_index) { + IrArray::Index tile_index = block_index; + for (int i = 0; i < block_sizes_.size(); ++i) { + tile_index[i] = b_->CreateMul( + block_index[i], + llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), + "block_origin." 
+ std::to_string(i)); + } + return tile_index; +} + +IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( + const IrArray::Index& tile_index) { + IrArray::Index elem_index = tile_index; + for (int i = DimY; i < DimTot; ++i) { + elem_index[i] = + b_->CreateMul(tile_index[i], + llvm::ConstantInt::get(tile_index[i]->getType(), + GetTileSizeForDimension(i)), + "tile_origin." + std::to_string(i)); } + return elem_index; +} + +llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name) { + // If shared memory transpose is needed, we use square tiles. + CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY()); + + // For Nvidia GPUs, the warp size is 32 threads and the shared memory bank is + // organized into 32-way. We usually use the warp size or a multiple of the + // warp size as the size for tiling. This may cause all elements in the same + // column of a tile to use the same memory bank and therefore cause shared + // memory bank conflicts. Adding 1 to the minor dimension of the shared memory + // buffer can reduce such shared memory bank conflicts. + llvm::Type* buffer_type = llvm::ArrayType::get( + llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1), + GetTileSizeForDimension(DimY)); + return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(), + buffer_type, buffer_name); +} - return IrArray::Index(linear_index, unreduced_output_shape, b); +std::tuple +KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { + // Calculate (y, x) coordinate of the thread in the 2D view of thread block + // defined by (num_thread_y, num_thread_x) from thread_id.
+ llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm::Value* thread_id_int = + b_->CreateIntCast(thread_id_raw, index_ty, + /*isSigned=*/true, "thread.id.x"); + llvm::Value* num_thread_x = + llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + return std::make_tuple(y, x); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 5ea05b3188a1c0881e4c0c41625d530aff1b1205..06002d57b0d7daa07f903feebe67a60a083c0e7c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -28,23 +28,160 @@ namespace llvm_ir { // If a shape can be viewed as three logical components 0-1-2 in the order of // major to minor, a 0-2-1-transpose changes the order of such logical // components to 0-2-1. We call the shape being transposed the input shape and -// the transposed shape the output shape. The logical view of the input and -// output shapes for the transpose are called the 0-1-2 shape or reduced input -// shape and the 0-2-1 shape or the reduced output shape respectively. The -// original input and output shapes are called the unreduced input and output -// shapes. - +// the transposed shape the output shape. The logical view of the input/output +// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized +// shapes. The original input/output shapes are called unnormalized shapes. +// // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the -// reduced shape of `b` or the 0-2-1 shape. +// normalized shape of `b` or the 0-2-1 shape. 
absl::optional > FindTranspose021(const Shape& a, const Shape& b); -// Return the unreduced output index corresponding to the given reduced output -// index. -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b); +// A tile is a spatial subdivision of a tensor. We group tensor elements into +// tiles so that we can launch kernels to process the tensor elements in blocks +// of tiles. +// +// A kernel mapping scheme describes a method to partition the tensors accessed +// by an unnested HLO instruction into tiles and blocks of tiles, and the +// associated information to use hardware threads to process the tensor elements +// in blocks of tiles. +// +// Currently, there are two main use cases for a tiling scheme. First, we +// implement kernels with 0-2-1 memory transpose using shared memory to improve +// memory access pattern. Second, we implement reduction to contiguous +// dimensions in layout, with or without memory transpose, to achieve better +// memory access pattern as well as to reduce the number of executed +// expensive instructions, such as thread synchronization related instructions +// and atomic operations. For both use cases, we can apply a normalization to +// the original tensors, to collapse contiguous dimensions for the same purpose +// and produce normalized three dimensional tensors. For this reason, the tiling +// scheme class only needs to handle normalized three dimensional tensors and +// two dimensional tiles. +// +// The current implementation of the class is somewhat NVIDIA GPU oriented. This +// situation can be improved when there is a need though. The idea of 0-2-1 +// transpose using shared memory can be found in the following CUDA algorithm in +// TensorFlow: https://goo.gl/MStRV6.
+// +// We use a thread block to process a tile because we want to use the HW thread +// block synchronization primitives to synchronize the processing of all the +// elements in the same tile. A thread block can be viewed as a two dimensional +// array of threads, described by the number of threads for the Y and X +// dimensions. A thread block (num_threads_y, num_threads_x) processes a tile of +// (tile_size_y, tile_size_x) as follows: each thread in the thread block +// processes one element in the tile so that all the threads in the thread block +// together process a subdivision of the tile that has the same dimension as the +// thread block array. Then the thread block moves on to process the next +// subdivision of the tile until the whole tile is processed. Therefore, each +// thread in the thread block processes +// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile. +// +// There are situations where we want a thread block to process multiple +// tiles. We can't group those tiles into a bigger tile because we limit a tile +// to a two dimensional spatial subdivision of a tensor. For example, when we +// use tiling to implement reduction with transpose, we want the partial sum +// produced by each thread to accumulate values for more elements before using +// shfl_down and atomic_add instructions for further reduction, to amortize the +// cost of such expensive instructions. The concept of tile block is introduced +// for this purpose. A tile block is a three dimensional array of tiles, of +// which some dimensions may be degenerated to only one tile. +class KernelMappingScheme { + public: + enum { DimZ = 0, DimY, DimX, DimTot }; + + public: + // dims_in_elems: the normalized tensor dimensions. + // req_block_sizes: the requested block size in number of tiles for each + // dimension. The actual block size is set to min(req_block_size, + // dims_in_number_of_tiles).
+ explicit KernelMappingScheme(absl::Span dims_in_elems, + int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); + + absl::Span GetDimensionsInElements() const { + return dims_in_elems_; + } + absl::Span GetDimensionsInTiles() const { + return dims_in_tiles_; + } + absl::Span GetDimensionsInBlocks() const { + return dims_in_blocks_; + } + + int64 GetNumberOfTilesInTotal() const { + return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies()); + } + int64 GetNumberOfTilesInOneBlock() const { + return absl::c_accumulate(block_sizes_, 1LL, std::multiplies()); + } + + int64 GetNumberOfBlocks() const { + return absl::c_accumulate(dims_in_blocks_, 1LL, std::multiplies()); + } + + int64 GetTileSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return tile_sizes_[d]; + } + int64 GetTileSizeForDimensionX() const { + return GetTileSizeForDimension(DimX); + } + int64 GetTileSizeForDimensionY() const { + return GetTileSizeForDimension(DimY); + } + + absl::Span GetBlockSizes() const { return block_sizes_; } + + int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } + int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } + + int64 GetThreadsPerTile() const { + return GetNumberOfThreadsForDimensionX() * + GetNumberOfThreadsForDimensionY(); + } + + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); + // Returns the index for the first tile in the block with the given block + // index. + IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index); + // Returns the index for the first element in the tile with the given tile + // index.
+ IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index); + + std::tuple EmitThreadYXCoordinate( + llvm::Type* index_ty); + + IrArray::Index GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape); + + llvm::GlobalVariable* GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name); + + private: + llvm::IRBuilder<>* b_; + // The number of elements in each dimension (non-owning view). + absl::Span dims_in_elems_; + + // The number of elements for each dimension of a tile. + std::vector tile_sizes_; + // The number of tiles in each dimension. It is computed from dims_in_elems_ + // and tile_sizes_. + std::vector dims_in_tiles_; + + // The number of tiles for each dimension of a tile block. + std::vector block_sizes_; + // The number of tile blocks in each dimension. It is computed from + // dims_in_tiles_ and block_sizes_. + std::vector dims_in_blocks_; + + // Number of threads used to process elements in the X direction of a tile. + int64 num_threads_x_; + // Number of threads used to process elements in the Y direction of a tile. + int64 num_threads_y_; +}; // A class to represent information for tiled parameters to support IR emission // for 021 transpose.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index df78726166eea953b57e72a5a5fc81ee246aca34..ceea24685af566e02340664f0a40c398c62b5ab0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -244,10 +244,11 @@ StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, int32 size_bytes) { - Shape shape; - TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes)); + ShapeProto shape_proto; + TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes)); + Shape shape(shape_proto); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - return shape; + return std::move(shape); } llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index fd16af67fe99b4f440ad962b4b648a3b22c41dc6..e22c2173c271fc9571be1ddb0759d2b31562dc98 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -47,7 +47,8 @@ namespace { // Adds the inner comparison loop body where we compare elements. void EmitCompareLoopBody( int64 iteration_bound, PrimitiveType key_type, int64 num_values, - llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, + int64 iota_values_parameter_index, llvm::Value* element_pair_index, + int64 xor_mask, llvm::Type* index_type, std::function read_element, std::function write_element, @@ -139,34 +140,42 @@ void EmitCompareLoopBody( is_signed_comparison = false; } // If key2 < key1 - ksl.IfReturnVoid( - "is_smaller_than", + auto is_smaller_than = b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1), - [&]() { - // Swap key1 with key2. 
- write_element(0, current_keys_index, key2); - write_element(0, compare_keys_index, key1); - for (int64 i = 1; i <= num_values; ++i) { - // Also swap the values. - auto value1 = read_element(i, current_keys_index); - auto value2 = read_element(i, compare_keys_index); - write_element(i, current_keys_index, value2); - write_element(i, compare_keys_index, value1); - } - }); + compare_key2, compare_key1); + if (iota_values_parameter_index >= 0) { + auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); + auto key_index1 = + read_element(iota_values_parameter_index, current_keys_index); + auto key_index2 = + read_element(iota_values_parameter_index, compare_keys_index); + auto index_is_smaller_than = + b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); + is_smaller_than = b->CreateOr( + is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + } + ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + // Swap key1 with key2. + write_element(0, current_keys_index, key2); + write_element(0, compare_keys_index, key1); + for (int64 i = 1; i <= num_values; ++i) { + // Also swap the values. 
+ auto value1 = read_element(i, current_keys_index); + auto value2 = read_element(i, compare_keys_index); + write_element(i, current_keys_index, value2); + write_element(i, compare_keys_index, value1); + } + }); }); } -void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, - int64 dimension_to_sort, - int64 dimension_to_sort_bound, - PrimitiveType keys_type, - absl::Span xor_masks, - const std::vector& params, - const std::vector& param_shmem_buffers, - int64 tile_size, llvm::IRBuilder<>* b) { +void EmitTiledCompareLoop( + const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, + int64 dimension_to_sort_bound, PrimitiveType keys_type, + absl::Span xor_masks, const std::vector& params, + const std::vector& param_shmem_buffers, + int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { KernelSupportLibrary ksl(b); llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); @@ -253,20 +262,22 @@ void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, RoundDownToNearest(dimension_to_sort_bound, tile_size))), [&]() { EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, - params.size() - 1, element_pair_index, xor_mask, + params.size() - 1, iota_values_parameter_index, + element_pair_index, xor_mask, tiled_keys_index.GetType(), read_element, write_element, b); }, [&]() { - EmitCompareLoopBody( - tile_size, keys_type, params.size() - 1, element_pair_index, - xor_mask, tiled_keys_index.GetType(), read_element, - write_element, b, /*needs_bounds_checks=*/false); + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), + read_element, write_element, b, + /*needs_bounds_checks=*/false); }); } else { EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - element_pair_index, xor_mask, - tiled_keys_index.GetType(), read_element, + iota_values_parameter_index, 
element_pair_index, + xor_mask, tiled_keys_index.GetType(), read_element, write_element, b, /*needs_bounds_checks=*/false); } // Wait until all comparisons have happened. @@ -296,6 +307,7 @@ void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, + int64 iota_values_parameter_index, absl::string_view name, absl::Span xor_masks, llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, @@ -367,8 +379,8 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, if (xor_masks.size() > 1) { EmitTiledCompareLoop(keys_index, dimension_to_sort, dimension_to_sort_bound, keys_shape.element_type(), - xor_masks, params, param_shmem_buffers, tile_size, - b); + xor_masks, params, param_shmem_buffers, + iota_values_parameter_index, tile_size, b); } else { auto read_element = [&](int64 operand, llvm::Value* index) { keys_index[dimension_to_sort] = index; @@ -380,9 +392,10 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, params[operand].EmitWriteArrayElement(keys_index, value, b); }; EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), - values_arrays.size(), tiles_index[rank - 1], - xor_masks[0], tiles_index.GetType(), read_element, - write_element, b); + values_arrays.size(), iota_values_parameter_index, + tiles_index[rank - 1], xor_masks[0], + tiles_index.GetType(), read_element, write_element, + b); } return Status::OK(); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 556a217322d373ffd5e816dcf35888b546806633..685f9383acba416f51681270e4037d56abb4b6ea 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -31,9 +31,12 @@ namespace llvm_ir { // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' // dimension of 
'keys_array'. All other dimensions are kept as-is. This // implements the inner loop of BitonicSort. It is assumed that 'xor_masks' -// contains only powers of 2, or values 2^k - 1 (k > 0). +// contains only powers of 2, or values 2^k - 1 (k > 0). If +// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand +// that is an iota and can be used to make the sorting stable. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, + int64 iota_values_parameter_index, absl::string_view name, absl::Span xor_masks, llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index cca37556173bb95ef062b59ab0a4bf9ca7c496fe..6c89700983363fec46c41b5430c6eab6b366a1b6 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -96,44 +96,18 @@ ExecutionOptions CreateExecutionOptions( const ExecutableBuildOptions& build_options, const ProgramShape* program_shape) { ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.hlo_profile().has_value()) { - execution_options.mutable_debug_options()->set_xla_hlo_profile( - *build_options.hlo_profile()); - } - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); - } - if (build_options.dump_optimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_optimized_hlo_proto_to( - build_options.dump_optimized_hlo_proto_to().value()); - } - if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_unoptimized_hlo_proto_to( - build_options.dump_unoptimized_hlo_proto_to().value()); - } - if
(build_options.dump_per_pass_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_per_pass_hlo_proto_to( - build_options.dump_per_pass_hlo_proto_to().value()); + if (build_options.has_debug_options()) { + *execution_options.mutable_debug_options() = build_options.debug_options(); } if (build_options.result_layout() != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); + build_options.result_layout()->ToProto(); } else { + Shape result_shape(program_shape->result()); + LayoutUtil::SetToDefaultLayout(&result_shape); *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + result_shape.ToProto(); } - - for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { - execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( - disabled_pass); - } - return execution_options; } @@ -145,7 +119,7 @@ StatusOr> LocalService::CompileExecutable( const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_host_program_shape()); - const ProgramShape& program_shape = proto.host_program_shape(); + ProgramShape program_shape(proto.host_program_shape()); // Validate incoming layouts. 
if (argument_layouts.size() != program_shape.parameters_size()) { @@ -220,4 +194,10 @@ StatusOr LocalService::GlobalDataToShapedBuffer( return buffers[replica_number]; } +StatusOr LocalService::RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag) { + return allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), tag); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 3b4f0b50832d6d2b64528ffb63eb5c7375396aec..f56ba32b04b9bf3aba75654bdb98887ad22e6791 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -63,6 +63,11 @@ class LocalService : public Service { StatusOr GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index ec52a24d782a44fda961feab3230886072e755c7..972a5b9ced0d84387ef8308efe2a7aff7317d047 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -113,6 +113,13 @@ Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleAddDependency( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand and does not + // create buffers. 
+ return Status::OK(); +} + Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) { // The top-level buffer (index={}) for kCopy is newly created, but all other // buffers (in the case of a tuple shape) come from the operand diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 81f524d84a8091e1fff13dc7c55b401143a02753..7ffca943d0f7805ad4420343fcdbf860415c4c40 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -64,6 +64,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 6152cdc6099a182f1ed98f9501613e0aa123cdbb..c35f72699bfe90f7b8021916c0f81d5e1926ff4c 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -44,32 +45,48 @@ namespace xla { // // This pattern will match Add instructions whose first operand is a constant. // -// Each pattern type has the following modifiers: +// Each pattern type has the following modifiers, which are described where +// nontrivial. 
// // Op(): -// - WithName: match operations with the given name -// - WithOpcode: match operations with the given opcode -// - WithShape: match operations whose shape matches the given pattern -// - WithOperand: match operations whose operand matches the given pattern +// - Is: is the given HloInstruction* (i.e. pointer equality) +// - WithName +// - WithOpcode +// - WithoutOpcode: anything other than the given opcode +// - WithShape: instr's shape matches the given pattern +// - WithShapeEqualTo: instr's shape is equal to the given Shape +// - WithShapeCompatibleTo: instr's shape is compatible with the given Shape +// - WithNumOperands +// - WithOperand: operand at the given index matches the given pattern +// - IsConstant +// - IsNonConstant +// - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value, +// e.g. IsConstantScalar() or IsConstantScalar(42). +// - WithFusionKind +// - WithTupleIndex: get-tuple-element operations with the given tuple index +// - WithOneUse: Instruction is used as an operand exactly once. +// - WithOneUser: Instruction is used by exactly one other instruction, but +// is possibly used more than once as an operand (e.g. multiply(x,x)). 
// // Shape(): -// - EqualTo: matches shapes that are equal to the argument -// - CompatibleTo: matches shapes that are compatible to the argument -// - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes -// - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format -// - WithLayout: match shapes whose layout matches the given pattern -// - WithLayoutEqualTo: matches shapes whose layouts equal the argument -// - WithSubshape: matches tuple shapes whose subshape matches the given -// pattern -// - WithSubshapeEqualTo: matches shapes with a subshape equal the argument -// - WithElementType: matches array/scalar shapes with the given element -// type -// - WithRank: matches array/scalar types with the given rank +// - EqualTo +// - CompatibleTo +// - IsScalar/IsEffectiveScalar/IsArray/IsTuple +// - IsDenseArray/IsSparseArray +// - WithLayout: layout shape's layout matches the given pattern (e.g. +// Layout().WithDenseFormat()) +// - WithLayoutEqualTo: shape's layout equals the argument (i.e. another +// Layout, but not the result of Layout().foo()) +// - WithSubshape: shape is a tuple whose subshape matches the given pattern +// (e.g. Shape().IsScalar()). +// - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg +// (i.e. 
another Shape, but not the result of Shape().foo()) +// - WithElementType: shape is an array/scalar with the given elem type +// - WithRank: shape is an array/scalar with the given rank // // Layout(): -// - EqualTo: matches layouts that are equal to the argument -// - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse -// format +// - EqualTo +// - WithDenseFormat/WithSparseFormat // // Op(), Shape(), and Layout() may be passed an argument of type // HloInstruction**, Shape**, or Layout**, respectively, or const versions of @@ -82,53 +99,55 @@ namespace xla { // CHECK(Match(foo, // match::Op().WithOperand(0, match::Op(&matched_operand)))); // -// Helpers are provided for common nullary, unary, binary, and ternary -// instructions. These helpers can be called with no arguments, in which case -// they will match any instruction matching the opcode. They may also be called -// with matches for the operands and with an optional capture. (The capture must -// be the first argument.) Some examples of these helpers and their equivalents -// are provided below. -// +// Helpers are provided for most HLO instructions. These helpers can be called +// with no arguments, in which case they will match any instruction matching the +// opcode. They may also be called with matches for the operands and with an +// optional capture. (The capture must be the first argument.) Some examples of +// these helpers and their equivalents are provided below. 
+ // Example nullary instruction: -// Param() == Op().WithOpcode(HloOpcode::kParam) -// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) +// Parameter() == Op().WithOpcode(HloOpcode::kParameter) +// Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter) // // Example unary instruction: -// Abs() == Op().WithOpcode(HloOpcode::kAbs) -// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) -// .WithOperand(0, Op(&a))) -// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) -// .WithOperand(0, Op(&b)) +// Abs() == Op().WithOpcode(HloOpcode::kAbs) +// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&a))) +// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&b)) +// +// Commutative binary instructions have a special form that accepts either order +// of args, e.g.: +// +// AddAnyOrder(Parameter(1), Abs()) == +// Op().WithOpcode(HloOpcode::kAdd) +// .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs()); // -// Example binary instruction: -// Add() == Op().WithOpcode(HloOpcode::kAdd) -// Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) -// .WithOperand(0, Op(&a)) -// .WithOperand(1, Op(&b)) -// Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) -// .WithOperand(0, Op(&b)) -// .WithOperand(1, Op(&c)) +// MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`. // -// Example ternary instruction: -// Clamp() == Op().WithOpcode(HloOpcode::kClamp) -// Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) -// .WithOperand(0, Op(&a)) -// .WithOperand(1, Op(&b)) -// .WithOperand(2, Op(&c)) -// Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) -// .WithOperand(0, Op(&b)) -// .WithOperand(1, Op(&c)) -// .WithOperand(2, Op(&d)) +// The following additional helpers are provided. In all cases, `&a` is +// optional. 
// +// ConstantScalar(&a) == Op(&a).IsConstantScalar(); +// ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v); +// ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar(); +// ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(&a, v) +// NonConstant(&a) == Op(&a).IsNonConstant() +// GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index) +// .WithOperand(0, b); +// Parameter(&a, n) == Op(&a).WithParameterNum(n); struct MatchOption { // If true, actually capture matched item into the user pointer. bool capture; + + // An explanation for why we failed to match is streamed here, if not-null. + std::ostream* explain_os; }; template bool Match(Value* value, const Pattern& pattern, - MatchOption option = {/*.capture=*/true}) { + MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) { if (option.capture) { auto new_option = option; new_option.capture = false; @@ -143,6 +162,77 @@ namespace match { namespace detail { +// Macro for streaming to option.explain_os if it's not null. +// +// EXPLAIN << "value of foo(): " << foo() +// +#pragma push_macro("EXPLAIN") +#define EXPLAIN \ + if (option.explain_os) *option.explain_os + +// kIndentInc is the additional number of spaces that we indent by when we +// increase the indent "by one". +enum { + kIndentInc = 2, +}; + +// Writes a newline and then `indent` spaces. +// +// We follow an unintuitive convention in this file's pretty-printers: Indents +// are performed by the caller, not the callee. For example, if you want to +// print +// +// foo: +// - bar +// +// you'd do: +// +// Foo::DescribeTo(std::ostream* os, int64 indent) { +// *os << "foo:"; +// Indent(os, indent) // Create a newline at the *current* indent level. +// *os << " - "; +// bar.DescribeTo(os, indent + 3); // + 3 because strlen(" * ") == 3. +// } +// +// Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; } +// +// Notice that Bar::DescribeTo() does not call Indent; the indenting is +// performed by Foo. 
This convention allows the caller to decide whether a +// matcher is preceded by a newline, which is important e.g. for the AllOf +// matcher. +// +// (Incidentally, indenting in Match's explanations is handled differently. +// Indents are a common case in DescribeTo [we're printing a whole tree], but +// they're a special case in Match [we're printing only a path through the tree +// that encounters a failing node]. Indents in Match only appear when we +// encounter a failing disjunction, so we just handle them as a special case +// there.) +inline void Indent(std::ostream* os, int64 indent) { + *os << "\n"; + for (int64 i = 0; i < indent; ++i) { + *os << " "; + } +} + +// SFINAE template that determines whether T declares a static member +// kIsTrivialMatcher. +// +// Trivial matchers get special treatment. For example, when printing +// a conjunction of matchers, we don't print "and" after a trivial matcher. This +// yields e.g. +// "a shape compatible with f32[1,2]" +// rather than +// "a shape AND compatible with f32[1,2]" +template +struct IsTrivialMatcher { + static constexpr bool value = false; +}; +template +struct IsTrivialMatcher::type> { + static constexpr bool value = true; +}; + template class AllOfPattern { public: @@ -162,10 +252,19 @@ class AllOfPattern { return matched; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + DescribeToImpl(os, std::integral_constant(), indent); + } + + // Accessor for patterns_. Please don't use this outside of this file. + const std::tuple& patterns() const { return patterns_; } + private: template bool MatchImpl(ItemType* item, MatchOption option, std::integral_constant) const { + // We don't need to do any EXPLAINing here; it's all correctly handled by + // our sub-matchers (if any fail). 
return std::get(patterns_).Match(item, option) && MatchImpl(item, option, std::integral_constant()); } @@ -176,6 +275,73 @@ class AllOfPattern { return true; } + // Pretty-printing a conjunction has some special cases to make it easy to + // read in the simple (common) case. + // + // If sizeof...(Patterns) == 1, prints as e.g. + // + // a shape + // + // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a + // shape") prints as + // + // a shape compatible with f32[1,2] + // + // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as + // + // a shape: + // * compatible with f32[1,2] AND + // * that represents a scalar + // + // Otherwise prints as: + // + // all of: + // * foo AND + // * bar + // + template + void DescribeToImpl(std::ostream* os, std::integral_constant, + int64 indent) const { + constexpr bool first_is_trivial = + IsTrivialMatcher(patterns_))>::type>::value; + constexpr bool is_last = index == sizeof...(Patterns) - 1; + const auto& submatcher = std::get(patterns_); + + auto print_bulleted_item = [&] { + *os << " * "; + submatcher.DescribeTo(os, indent + 3); + if (!is_last) { + *os << " AND"; + Indent(os, indent); + } + }; + + if (index == 0) { + if (first_is_trivial || is_last) { + submatcher.DescribeTo(os, indent + kIndentInc); + if (sizeof...(Patterns) > 2) { + *os << ":"; + Indent(os, indent); + } + } else { + *os << "all of:"; + Indent(os, indent); + print_bulleted_item(); + } + } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) { + *os << " "; + submatcher.DescribeTo(os, indent); + } else { + print_bulleted_item(); + } + DescribeToImpl(os, std::integral_constant(), indent); + } + + void DescribeToImpl(std::ostream* os, + std::integral_constant, + int64 indent) const {} + std::tuple patterns_; }; @@ -183,10 +349,6 @@ class AllOfPattern { // Returns a pattern that represents the conjunction of all input patterns. 
All // patterns need to match in order to have the AllOf pattern match. -// -// TODO(timshen): Currently AllOf is still nested, e.g. AllOf, B> is -// not AllOf. We might want to flatten the AllOf type structure if the -// C++ compile error message gets annoying. template detail::AllOfPattern::type, Patterns...> AllOf( const Patterns&... patterns) { @@ -194,6 +356,25 @@ detail::AllOfPattern::type, Patterns...> AllOf( Patterns...>(patterns...); } +// AllOf, X, Y, ...> => AllOf. +// +// This transformation is necessary for good pretty-printing. +template +detail::AllOfPattern::type, InnerPs..., + OuterPs...> +AllOf(const detail::AllOfPattern& inner_p, + const OuterPs&... outer_ps) { + // Invoke constructor of AllOfPattern. + auto make_all_of = [](const InnerPs&... inner_ps, + const OuterPs&... outer_ps) { + return detail::AllOfPattern::type, + InnerPs..., OuterPs...>(inner_ps..., + outer_ps...); + }; + return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(), + std::make_tuple(outer_ps...))); +} + namespace detail { template @@ -204,8 +385,18 @@ class LayoutPattern; class LayoutPatternBaseImpl { public: bool Match(const ::xla::Layout* layout, MatchOption option) const { - return layout != nullptr; + if (layout == nullptr) { + EXPLAIN << "Layout is null"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "a layout"; } + + static constexpr bool kIsTrivialMatcher = true; }; // A LayoutPattern implementation that matches only if the layout equals a @@ -216,7 +407,17 @@ class LayoutPatternEqualImpl { : layout_(layout) {} bool Match(const ::xla::Layout* layout, MatchOption option) const { - return LayoutUtil::Equal(*layout_, *layout); + if (!LayoutUtil::Equal(*layout_, *layout)) { + EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout) + << " is not equal to expected " + << LayoutUtil::HumanString(*layout_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) 
const { + *os << "equal to " << LayoutUtil::HumanString(*layout_); } private: @@ -230,7 +431,16 @@ class LayoutPatternFormatImpl { explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} bool Match(const ::xla::Layout* layout, MatchOption option) const { - return layout->format() == format_; + if (layout->format() != format_) { + EXPLAIN << "Layout has format " << Format_Name(layout->format()) + << " but expected " << Format_Name(format_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with format " << Format_Name(format_); } private: @@ -242,11 +452,13 @@ template class LayoutPattern { private: template - LayoutPattern> - AppendImpl(NewImpl new_impl) const { - return LayoutPattern>( - AllOf(impl_, std::move(new_impl)), matched_layout_); + auto AppendImpl(NewImpl new_impl) const + -> LayoutPattern(std::declval(), + std::move(new_impl)))> { + auto new_allof = AllOf(impl_, std::move(new_impl)); + return LayoutPattern(std::move(new_allof), + matched_layout_); } public: @@ -276,6 +488,10 @@ class LayoutPattern { return false; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + impl_.DescribeTo(os, indent); + } + // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr auto EqualTo(const ::xla::Layout* layout) const @@ -306,19 +522,48 @@ class AnyOfPattern { explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) 
{} bool Match(const Item* item, MatchOption option) const { - return MatchImpl(item, option, std::integral_constant()); + return MatchImpl(item, option); } bool Match(Item* item, MatchOption option) const { - return MatchImpl(item, option, std::integral_constant()); + return MatchImpl(item, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "any of:"; + Indent(os, indent); + DescribeToImpl(os, std::integral_constant(), indent); } private: + template + bool MatchImpl(ItemType* item, MatchOption option) const { + // If we're generating an explanation, buffer it until we know we failed. + absl::optional explanation; + MatchOption new_option = option; + if (option.explain_os) { + new_option.explain_os = &explanation.emplace(); + } + bool rv = MatchRecursiveImpl(item, new_option, + std::integral_constant()); + if (!rv && option.explain_os) { + EXPLAIN << "None of the following matchers succeeded:"; + EXPLAIN << explanation->str(); + } + return rv; + } + template - bool MatchImpl(ItemType* item, MatchOption option, - std::integral_constant) const { + bool MatchRecursiveImpl(ItemType* item, MatchOption option, + std::integral_constant) const { auto new_option = option; new_option.capture = false; + + absl::optional explanation; + if (option.explain_os) { + new_option.explain_os = &explanation.emplace(); + } + // Try to match the sub-pattern without capturing behavior. if (std::get(patterns_).Match(item, new_option)) { // Capture the branch. @@ -337,20 +582,46 @@ class AnyOfPattern { // AnyOf will be a runtime number indicate which sub-pattern is matched. // Then we run another pass to do captures only with the help of the // trace. 
- bool ret = std::get(patterns_).Match(item, option); - DCHECK(ret); + bool matched = std::get(patterns_).Match(item, option); + DCHECK(matched); } return true; } - return MatchImpl(item, option, std::integral_constant()); + if (option.explain_os) { + EXPLAIN << "\nMatcher #" << index + 1; + EXPLAIN << "\n - "; + std::get(patterns_).DescribeTo(option.explain_os, /*indent=*/3); + EXPLAIN << "\nfailed with"; + EXPLAIN << "\n - "; + EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}}); + } + return MatchRecursiveImpl(item, option, + std::integral_constant()); } template - bool MatchImpl(ItemType* item, MatchOption option, - std::integral_constant) const { + bool MatchRecursiveImpl( + ItemType* item, MatchOption option, + std::integral_constant) const { return false; } + template + void DescribeToImpl(std::ostream* os, std::integral_constant, + int64 indent) const { + *os << " - "; + std::get(patterns_).DescribeTo(os, indent + 3); + if (index != sizeof...(Patterns) - 1) { + *os << " OR"; + Indent(os, indent); + } + DescribeToImpl(os, std::integral_constant(), indent); + } + + void DescribeToImpl(std::ostream* os, + std::integral_constant, + int64 indent) const {} + std::tuple patterns_; }; @@ -395,8 +666,17 @@ class ShapePattern; class ShapePatternBaseImpl { public: bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (shape == nullptr) { + EXPLAIN << "Shape is null"; + } return shape != nullptr; } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "a shape"; + } + + static constexpr bool kIsTrivialMatcher = true; }; // A ShapePattern implementation that matches only if the shape equals a Shape @@ -407,7 +687,16 @@ class ShapePatternEqualImpl { : shape_(shape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Equal(*shape_, *shape); + if (!ShapeUtil::Equal(*shape_, *shape)) { + EXPLAIN << "Shape not equal to " + << ShapeUtil::HumanStringWithLayout(*shape_); + return false; + } 
+ return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_); } private: @@ -422,7 +711,16 @@ class ShapePatternCompatibleImpl { : shape_(shape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Compatible(*shape_, *shape); + if (!ShapeUtil::Compatible(*shape_, *shape)) { + EXPLAIN << "Shape not compatible with " + << ShapeUtil::HumanString(*shape_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "compatible with " << ShapeUtil::HumanString(*shape_); } private: @@ -437,7 +735,16 @@ class ShapePatternElementTypeImpl { : element_type_(element_type) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return shape->element_type() == element_type_; + if (shape->element_type() != element_type_) { + EXPLAIN << "Shape does not have element type " + << PrimitiveType_Name(element_type_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with element type " << PrimitiveType_Name(element_type_); } private: @@ -450,7 +757,15 @@ class ShapePatternIsScalarImpl { explicit constexpr ShapePatternIsScalarImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsScalar(*shape); + if (!ShapeUtil::IsScalar(*shape)) { + EXPLAIN << "Shape is not a scalar"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents a scalar"; } }; @@ -460,7 +775,15 @@ class ShapePatternIsArrayImpl { explicit constexpr ShapePatternIsArrayImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsArray(*shape); + if (!ShapeUtil::IsArray(*shape)) { + EXPLAIN << "Shape is not an array"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os 
<< "that represents an array"; } }; @@ -470,7 +793,34 @@ class ShapePatternIsTupleImpl { explicit constexpr ShapePatternIsTupleImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsTuple(*shape); + if (!ShapeUtil::IsTuple(*shape)) { + EXPLAIN << "Shape is not a tuple"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents a tuple"; + } +}; + +// A ShapePattern implementation that matches only if the shape is an effective +// scalar. +class ShapePatternEffectiveScalarImpl { + public: + explicit constexpr ShapePatternEffectiveScalarImpl() {} + + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (!ShapeUtil::IsEffectiveScalar(*shape)) { + EXPLAIN << "Shape is not an effective scalar"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that is an effective scalar"; } }; @@ -481,7 +831,23 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Rank(*shape) == rank_; + if (ShapeUtil::Rank(*shape) != rank_) { + if (rank_ == 0) { + EXPLAIN << "Shape is not a scalar"; + } else { + EXPLAIN << "Shape does not have rank " << rank_; + } + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + if (rank_ == 0) { + *os << "that is a scalar"; + } else { + *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? 
"s" : ""); + } } private: @@ -503,8 +869,21 @@ class ShapePatternLayoutImpl { } bool Match(Shape* shape, MatchOption option) const { - return LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout(), option); + if (!LayoutUtil::HasLayout(*shape)) { + EXPLAIN << "Shape does not have a layout"; + return false; + } + if (!layout_.Match(shape->mutable_layout(), option)) { + EXPLAIN << "\nin layout"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with"; + Indent(os, indent + kIndentInc); + layout_.DescribeTo(os, indent + kIndentInc); } private: @@ -522,17 +901,40 @@ class ShapePatternSubshapeImpl { : index_(index), subshape_(subshape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); + return MatchImpl(shape, option); } bool Match(::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), - option); + return MatchImpl(shape, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with subshape at index " << index_.ToString() << " which is"; + Indent(os, indent + kIndentInc); + subshape_.DescribeTo(os, indent + kIndentInc); } private: + Shape* GetSubshape(Shape* shape) const { + return ShapeUtil::GetMutableSubshape(shape, index_); + } + const Shape* GetSubshape(const Shape* shape) const { + return &ShapeUtil::GetSubshape(*shape, index_); + } + + template + bool MatchImpl(ShapeType* shape, MatchOption option) const { + if (!ShapeUtil::IndexIsValid(*shape, index_)) { + EXPLAIN << "No subshape at " << index_.ToString(); + return false; + } + if (!subshape_.Match(GetSubshape(shape), option)) { + EXPLAIN << "\nin subshape at " << index_.ToString(); + return false; + } + return true; + } + ShapeIndexView index_; ShapePattern 
subshape_; }; @@ -542,10 +944,12 @@ template class ShapePattern { private: template - ShapePattern> AppendImpl( - NewImpl new_impl) const { - return ShapePattern>( - AllOf(impl_, std::move(new_impl)), matched_shape_); + auto AppendImpl(NewImpl new_impl) const + -> ShapePattern(std::declval(), + std::move(new_impl)))> { + auto new_all_of = AllOf(impl_, std::move(new_impl)); + return ShapePattern(std::move(new_all_of), + matched_shape_); } public: @@ -560,6 +964,11 @@ class ShapePattern { } return true; } + if (shape) { + EXPLAIN << "\nin " + << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) + : ShapeUtil::HumanString(*shape)); + } return false; } @@ -571,9 +980,16 @@ class ShapePattern { } return true; } + EXPLAIN << "\nin " + << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) + : ShapeUtil::HumanString(*shape)); return false; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + return impl_.DescribeTo(os, indent); + } + // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. constexpr auto EqualTo(const ::xla::Shape* shape) const @@ -612,6 +1028,11 @@ class ShapePattern { return AppendImpl(ShapePatternIsTupleImpl()); } + constexpr auto IsEffectiveScalar() const + -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) { + return AppendImpl(ShapePatternEffectiveScalarImpl()); + } + // Modifies the pattern to match only if the shape has the given rank. constexpr auto WithRank(int64 rank) const -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { @@ -706,6 +1127,22 @@ Shape(::xla::Shape** matched_shape) { namespace detail { +// Overloads to get a const or non-const operand out of an instruction. 
+inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) { + return instr->mutable_operand(idx); +} +inline const HloInstruction* HloOperand(const HloInstruction* instr, + int64 idx) { + return instr->operand(idx); +} + +// Pretty-printer for HloInstruction. Sort of like ToShortString, but with +// fewer %s and more shapes. +inline string InstToString(const HloInstruction* inst) { + return inst->ToString( + HloPrintOptions().set_print_metadata(false).set_print_percent(false)); +} + template class HloInstructionPattern; @@ -714,8 +1151,18 @@ class HloInstructionPattern; class HloInstructionPatternBaseImpl { public: bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst != nullptr; + if (inst == nullptr) { + EXPLAIN << "HloInstruction* is null"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "an HloInstruction"; } + + static constexpr bool kIsTrivialMatcher = true; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -726,13 +1173,44 @@ class HloInstructionPatternNameImpl { : name_(name) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->name() == name_; + if (inst->name() != name_) { + EXPLAIN << "HloInstruction not named \"" << name_ << "\""; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "named \"" << name_ << "\""; } private: absl::string_view name_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// equals a particular pointer. 
+class HloInstructionIsImpl { + public: + explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst != inst_) { + EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" + << InstToString(inst_) << ")"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; + } + + private: + const HloInstruction* inst_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. class HloInstructionPatternOpcodeImpl { @@ -742,7 +1220,25 @@ class HloInstructionPatternOpcodeImpl { : opcode_(opcode), invert_(invert) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return (invert_ ^ (inst->opcode() == opcode_)); + if (invert_ && inst->opcode() == opcode_) { + EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_) + << ", expected anything else"; + return false; + } + if (!invert_ && inst->opcode() != opcode_) { + EXPLAIN << "HloInstruction doesn't have opcode " + << HloOpcodeString(opcode_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + if (!invert_) { + *os << "with opcode " << HloOpcodeString(opcode_); + } else { + *os << "with any opcode other than " << HloOpcodeString(opcode_); + } } private: @@ -757,8 +1253,17 @@ class HloInstructionPatternNumOperandsImpl { explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands) : num_operands_(num_operands) {} - bool Match(const ::xla::HloInstruction* inst, MatchOption /*option*/) const { - return inst->operand_count() == num_operands_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst->operand_count() != num_operands_) { + EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands"; + return 
false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with " << num_operands_ << " operand" + << (num_operands_ != 1 ? "s" : ""); } private: @@ -775,11 +1280,25 @@ class HloInstructionPatternShapeImpl { : shape_(shape) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return shape_.Match(&inst->shape(), option); + if (!shape_.Match(&inst->shape(), option)) { + EXPLAIN << "\nin output shape"; + return false; + } + return true; } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return shape_.Match(inst->mutable_shape(), option); + if (!shape_.Match(inst->mutable_shape(), option)) { + EXPLAIN << "\nin output shape"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "outputting"; + Indent(os, indent + kIndentInc); + shape_.DescribeTo(os, indent + kIndentInc); } private: @@ -797,20 +1316,197 @@ class HloInstructionPatternOperandImpl { : operand_index_(operand_index), operand_(operand) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_), option); + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return operand_index_ < inst->operand_count() && - operand_.Match(inst->mutable_operand(operand_index_), option); + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with operand " << operand_index_ << " which is:"; + Indent(os, indent + kIndentInc); + operand_.DescribeTo(os, indent + kIndentInc); } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (operand_index_ >= inst->operand_count()) { + EXPLAIN << "desired operand index " << operand_index_ + << " is out of bounds"; + return false; + } + if (!operand_.Match(HloOperand(inst, 
operand_index_), option)) { + EXPLAIN << "\nin operand " << operand_index_; + return false; + } + return true; + } + int64 operand_index_; HloInstructionPattern operand_; }; +// Matches a binary instruction whose operands come in any order. +template +class HloInstructionPatternBinaryOperandsAnyOrderImpl { + public: + explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl( + const HloInstructionPattern& op1, + const HloInstructionPattern& op2) + : op1_(op1), op2_(op2) {} + + bool Match(HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with two operands in either order:"; + Indent(os, indent); + *os << " - "; + op1_.DescribeTo(os, indent + 3); + Indent(os, indent); + *os << " - "; + op2_.DescribeTo(os, indent + 3); + } + + private: + HloInstruction* operand(HloInstruction* inst, int64 idx) const { + return inst->mutable_operand(idx); + } + const HloInstruction* operand(const HloInstruction* inst, int64 idx) const { + return inst->operand(idx); + } + + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + // We could implement this using AnyOf and AllOf matchers, but the templates + // get pretty difficult to debug, since any compile error herein becomes + // not-an-error via SFINAE. Also this way lets us give better messages on + // failure. + if (inst->operand_count() != 2) { + EXPLAIN << "HloInstruction did not have two operands"; + return false; + } + + // If we're not generating explanations, this is pretty simple. 
+ if (!option.explain_os) { + auto try_match = [&](int64 idx1, int64 idx2) { + MatchOption new_option = option; + new_option.capture = false; + if (op1_.Match(operand(inst, idx1), new_option) && + op2_.Match(operand(inst, idx2), new_option)) { + if (option.capture) { + bool matched = op1_.Match(operand(inst, idx1), option) && + op2_.Match(operand(inst, idx2), option); + DCHECK(matched); + } + return true; + } + return false; + }; + return try_match(0, 1) || try_match(1, 0); + } + + // If we are generating explanations, we have some work to do in order to + // generate a helpful error. + // + // First, try all four operand/matcher combinations, recording the + // failure explanations separately from option.explain_os. matches[i][j] + // tells us if matcher_i matches operand j. + bool matches[/*matcher*/ 2][/*operand*/ 2]; + std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2]; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + MatchOption new_option = option; + new_option.capture = false; + new_option.explain_os = &explanations[i][j]; + matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option) + : op2_.Match(operand(inst, j), new_option); + } + } + + // Check if the match succeeded. + for (int i = 0; i < 2; ++i) { + if (matches[0][i] && matches[1][(i + 1) % 2]) { + // Rerun the matches with capture enabled if necessary. + if (option.capture) { + auto* operand1 = operand(inst, i); + auto* operand2 = operand(inst, (i + 1) % 2); + bool matched = + op1_.Match(operand1, option) && op2_.Match(operand2, option); + DCHECK(matched); + } + return true; + } + } + + auto describe_matcher = [&](int matcher_idx) { + EXPLAIN << "\n - "; + if (matcher_idx == 0) { + op1_.DescribeTo(option.explain_os, /*indent=*/3); + } else { + CHECK_EQ(matcher_idx, 1); + op2_.DescribeTo(option.explain_os, /*indent=*/3); + } + for (int i = 0; i < 2; ++i) { + if (matches[matcher_idx][/*operand*/ i]) { + continue; + } + EXPLAIN << "\ndoes not match " << (i == 0 ? 
"LHS" : "RHS") << ":\n"; + EXPLAIN << " - "; + EXPLAIN << absl::StrReplaceAll( + explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n "}}); + } + }; + + // If we failed to match, one of the following is true: + // 1. op1 (op2) matches neither LHS nor RHS, or + // 2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS). + // We print different explanations depending on which case we're in. + + // Case 1. + bool wrote_explanation = false; + for (int i = 0; !wrote_explanation && i < 2; ++i) { + if (!matches[i][0] && !matches[i][1]) { + EXPLAIN << "HloInstruction's operands (ignoring order) did not match " + << (i == 0 ? "first" : "second") << " matcher. Specifically,"; + describe_matcher(i); + wrote_explanation = true; + } + } + + // Case 2. + for (int i = 0; !wrote_explanation && i < 2; ++i) { + if (matches[/*matcher*/ 0][/*operand*/ i] && + matches[/*matcher*/ 1][/*operand*/ i]) { + CHECK(!matches[0][(i + 1) % 2]); + CHECK(!matches[1][(i + 1) % 2]); + CHECK(!wrote_explanation); + EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS") + << " operand did not match either of the two matchers. " + "Specifically,"; + describe_matcher(0); + EXPLAIN << "\nand"; + describe_matcher(1); + wrote_explanation = true; + } + } + + CHECK(wrote_explanation); + return false; + } + + HloInstructionPattern op1_; + HloInstructionPattern op2_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. 
class HloInstructionPatternFusionKindImpl { @@ -820,14 +1516,32 @@ class HloInstructionPatternFusionKindImpl { : kind_(kind) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with fusion kind " << ToString(kind_); } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kFusion) { + EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_) + << "; it's not a fusion"; + return false; + } + if (inst->fusion_kind() != kind_) { + EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_); + return false; + } + return true; + } + ::xla::HloInstruction::FusionKind kind_; }; @@ -839,47 +1553,211 @@ class HloInstructionPatternTupleIndexImpl { : tuple_index_(tuple_index) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kGetTupleElement && - inst->tuple_index() == tuple_index_; + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kGetTupleElement && - inst->tuple_index() == tuple_index_; + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is a GTE with index " << tuple_index_; } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kGetTupleElement) { + EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_ + << "; it's not a GTE at all"; + return false; + } + if 
(inst->tuple_index() != tuple_index_) { + EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_; + return false; + } + return true; + } + int64 tuple_index_; }; -template -class HloPredicatePatternImpl { +class HloInstructionPatternParameterNumImpl { public: - explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num) + : parameter_num_(parameter_num) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } - bool Match(const ItemType* item, MatchOption option) const { - return pred_(item); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); } - bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is parameter " << parameter_num_; + } private: - Predicate pred_; + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kParameter || + inst->parameter_number() != parameter_num_) { + EXPLAIN << "HloInstruction is not parameter " << parameter_num_; + return false; + } + return true; + } + + int64 parameter_num_; }; -struct PatternFriend; +// Superclass that contains common code used by Op::WithOneUse() and +// Op::WithOneUser(). 
+class HloInstructionPatternOneUseOrUserImpl { + protected: + bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != 1) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly one."; + if (inst->user_count() > 1) { + EXPLAIN << "\nAll users:"; + for (const HloInstruction* user : inst->users()) { + EXPLAIN << "\n - " << InstToString(user); + } + } + return false; + } + return true; + } +}; + +class HloInstructionPatternOneUseImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + if (!MatchOneUser(inst, option)) { + return false; + } + + int64 use_count = absl::c_count_if( + inst->users()[0]->operands(), + [&](const HloInstruction* operand) { return operand == inst; }); + if (use_count != 1) { + EXPLAIN << "HloInstruction is used " << use_count + << " times by its user, but is expected to be used just once: " + << InstToString(inst->users()[0]); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one use"; + } +}; + +class HloInstructionPatternOneUserImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchOneUser(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one user (but possibly is used multiple times by " + "that instruction)"; + } +}; + +// Matches a constant scalar or effective scalar, optionally with a given value. 
+template +class HloConstantScalarImpl { + public: + explicit constexpr HloConstantScalarImpl(bool match_effective_scalar) + : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {} + + constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar) + : val_(val), match_effective_scalar_(match_effective_scalar) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is a constant " + << (match_effective_scalar_ ? "effective " : "") << "scalar"; + if (val_.has_value()) { + *os << " with value " << *val_; + } + } + + private: + template + bool MatchImpl(InstTy* inst, MatchOption option) const { + const auto* const_inst = DynCast(inst); + if (!const_inst) { + EXPLAIN << "HloInstruction is not a constant"; + return false; + } + if (match_effective_scalar_ && + !ShapeUtil::IsEffectiveScalar(inst->shape())) { + EXPLAIN << "HloInstruction is not an effective scalar"; + return false; + } + if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) { + EXPLAIN << "HloInstruction is not a scalar"; + return false; + } + if (!val_.has_value()) { + return true; + } + + // Check that literal == static_cast(val) and + // val == static_cast(literal). This is sufficient to ensure that + // the two constant scalars are actually "equal". 
+ auto val_literal = LiteralUtil::CreateR0(*val_); + auto literal_r0_or = const_inst->literal().Reshape({}); + auto val_as_literal_ty_or = + val_literal.Convert(const_inst->shape().element_type()); + if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) { + EXPLAIN << "could not construct relevant Literals (how did this happen?)"; + return false; + } + auto literal_r0 = std::move(literal_r0_or).ValueOrDie(); + auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie(); + auto literal_r0_as_val_ty_or = + literal_r0.Convert(val_literal.shape().element_type()); + bool rv = literal_r0_as_val_ty_or.ok() && // + literal_r0_as_val_ty_or.ValueOrDie() == val_literal && + literal_r0 == val_as_literal_ty; + if (!rv) { + EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + << " did not match expected value " << *val_; + } + return rv; + } + + absl::optional val_; + bool match_effective_scalar_; +}; // A pattern that matches HloInstructions. template class HloInstructionPattern { private: template - HloInstructionPattern> - AppendImpl(NewImpl new_impl) const { - return HloInstructionPattern< - HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( - AllOf(impl_, std::move(new_impl)), matched_inst_); + auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< + HloInstructionType, decltype(AllOf( + std::declval(), std::move(new_impl)))> { + auto new_allof = AllOf(impl_, std::move(new_impl)); + return HloInstructionPattern( + std::move(new_allof), matched_inst_); } public: @@ -895,6 +1773,9 @@ class HloInstructionPattern { } return true; } + if (inst != nullptr) { + EXPLAIN << "\nin " << InstToString(inst); + } return false; } @@ -906,6 +1787,7 @@ class HloInstructionPattern { } return true; } + EXPLAIN << "\nin " << InstToString(inst); return false; } @@ -935,12 +1817,47 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } + constexpr auto Is(const HloInstruction* instr) 
const + -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) { + return AppendImpl(HloInstructionIsImpl(instr)); + } + // Modifies the pattern to match only if the instruction is a constant. constexpr auto IsConstant() const -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } + constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/false))) { + return AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/false)); + } + + // This does not check that T has the same type as the instruction, so e.g. + // IsConstantScalar(1.0) may match a constant of shape int32[]. + template + constexpr auto IsConstantScalar(const ScalarTy& val) const + -> decltype(this->AppendImpl(HloConstantScalarImpl( + val, /*match_effective_scalar=*/false))) { + return AppendImpl( + HloConstantScalarImpl(val, /*match_effective_scalar=*/false)); + } + + constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/true))) { + return AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/true)); + } + + template + constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const + -> decltype(this->AppendImpl(HloConstantScalarImpl( + val, /*match_effective_scalar=*/true))) { + return AppendImpl( + HloConstantScalarImpl(val, /*match_effective_scalar=*/true)); + } + // Modifies the pattern to match only if the instruction is not a constant. constexpr auto IsNonConstant() const -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { @@ -957,6 +1874,22 @@ class HloInstructionPattern { HloInstructionPatternShapeImpl(shape)); } + // Make this a templated function to work around gcc 4.9.4 template infinite + // recursion bug. 
+ template + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + -> decltype(this->WithShape(Shape().EqualTo(shape))) { + return WithShape(Shape().EqualTo(shape)); + } + + // Make this a templated function to work around gcc 4.9.4 template infinite + // recursion bug. + template + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { + return WithShape(Shape().CompatibleTo(shape)); + } + // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. template @@ -971,6 +1904,20 @@ class HloInstructionPattern { operand_index, operand)); } + template + constexpr auto WithBinaryOperandsAnyOrder( + const HloInstructionPattern& op1, + const HloInstructionPattern& op2) const + -> decltype(this->AppendImpl( + HloInstructionPatternBinaryOperandsAnyOrderImpl< + OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, + op2))) { + return AppendImpl( + HloInstructionPatternBinaryOperandsAnyOrderImpl< + OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2)); + } + // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const @@ -985,17 +1932,34 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } - private: - template - constexpr auto WithPredicate(Predicate pred) const -> decltype( - this->AppendImpl(HloPredicatePatternImpl( - std::move(pred)))) { - return AppendImpl( - HloPredicatePatternImpl(std::move(pred))); + // Modifies the pattern to match only if the instruction is a parameter + // with the given parameter number. 
+ constexpr auto WithParameterNum(int64 parameter_num) const -> decltype( + this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) { + return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } - friend struct PatternFriend; + // Modifies the pattern to match if the instruction is used exactly once. + // Does not match if the instruction is used twice by the same user (e.g. + // multiply(x,x)). + constexpr auto WithOneUse() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + return AppendImpl(HloInstructionPatternOneUseImpl()); + } + + // Modifies the pattern to match if the instruction is used by exactly one + // other instruction. Will match if the instruction is used twice, so long as + // it's by the same user (e.g. multiply(x,x)). + constexpr auto WithOneUser() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + return AppendImpl(HloInstructionPatternOneUserImpl()); + } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + impl_.DescribeTo(os, indent); + } + + private: Impl impl_; HloInstructionType** matched_inst_; }; @@ -1036,6 +2000,7 @@ Op(::xla::HloInstruction** matched_inst) { XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) XLA_NULLOP_PATTERN(Iota) +XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. 
@@ -1067,8 +2032,10 @@ XLA_UNOP_PATTERN(RoundNearestAfz) XLA_UNOP_PATTERN(Bitcast) XLA_UNOP_PATTERN(Broadcast) XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(CrossReplicaSum) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) @@ -1088,6 +2055,7 @@ XLA_UNOP_PATTERN(Reverse) XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) +XLA_UNOP_PATTERN(Slice) XLA_UNOP_PATTERN(Sort) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) @@ -1125,25 +2093,32 @@ XLA_UNOP_PATTERN(Transpose) #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ XLA_BINOP_PATTERN(NAME) \ \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(AnyOf(NAME(lhs, rhs), NAME(rhs, lhs))) { \ - return AnyOf(NAME(lhs, rhs), NAME(rhs, lhs)); \ - } \ - \ template \ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ Rhs&& rhs) \ - ->decltype(AnyOf(NAME(matched_inst, lhs, rhs), \ - NAME(matched_inst, rhs, lhs))) { \ - return AnyOf(NAME(matched_inst, lhs, rhs), \ - NAME(matched_inst, rhs, lhs)); \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs))) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) +XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -1155,7 +2130,9 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) 
XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) +XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(ReduceWindow) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) @@ -1202,6 +2179,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) .WithOperand(2, std::forward(arg2)); \ } XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Scatter); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN @@ -1254,32 +2232,12 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, // We could implement all ops as "variadic" ops, but it would make the // already-bad compile errors even worse. +XLA_VARIADIC_OP_PATTERN(AfterAll); XLA_VARIADIC_OP_PATTERN(Concatenate); +XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(Map) XLA_VARIADIC_OP_PATTERN(Reduce); - -namespace detail { -struct PatternFriend { - template - static auto ConstantScalar(T constant) -> decltype( - Constant() - .WithShape(match::Shape().IsScalar()) - .WithPredicate( - std::declval>())) { - std::function pred = - [constant](const HloInstruction* instr) { - const auto& literal = Cast(instr)->literal(); - auto status_or_const = LiteralUtil::CreateR0(constant).Convert( - literal.shape().element_type()); - return status_or_const.ok() && - literal == status_or_const.ConsumeValueOrDie(); - }; - - return Constant() - .WithShape(match::Shape().IsScalar()) - .WithPredicate(std::move(pred)); - } -}; -} // namespace detail +XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for matching non-constant instructions. 
inline auto NonConstant() -> decltype(Op().IsNonConstant()) { @@ -1318,14 +2276,71 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } -template -inline auto ConstantScalar(T constant) - -> decltype(detail::PatternFriend::ConstantScalar(constant)) { - return detail::PatternFriend::ConstantScalar(constant); +// Add overloads for Parameter which take an int64 specifying the parameter +// number. +inline auto Parameter(int64 parameter_num) -> decltype( + Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) { + return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num); +} +template +inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) + -> decltype(Op(matched_inst) + .WithOpcode(HloOpcode::kParameter) + .WithParameterNum(parameter_num)) { + return Op(matched_inst) + .WithOpcode(HloOpcode::kParameter) + .WithParameterNum(parameter_num); +} + +inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) { + return Op().IsConstantScalar(); +} + +template +inline auto ConstantScalar(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsConstantScalar()) { + return Op(matched_inst).IsConstantScalar(); +} + +template +inline auto ConstantScalar(ScalarTy val) + -> decltype(Op().IsConstantScalar(val)) { + return Op().IsConstantScalar(val); +} + +template +inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) + -> decltype(Op(matched_inst).IsConstantScalar(val)) { + return Op(matched_inst).IsConstantScalar(val); +} + +inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantEffectiveScalar()) { + return Op().IsConstantEffectiveScalar(); +} + +template +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsConstantEffectiveScalar()) { + return Op(matched_inst).IsConstantEffectiveScalar(); +} + +template +inline auto ConstantEffectiveScalar(ScalarTy val) + -> 
decltype(Op().IsConstantEffectiveScalar(val)) { + return Op().IsConstantEffectiveScalar(val); +} + +template +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst, + ScalarTy val) + -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) { + return Op(matched_inst).IsConstantEffectiveScalar(val); } } // namespace match } // namespace xla +#undef EXPLAIN +#pragma pop_macro("EXPLAIN") #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock.h b/tensorflow/compiler/xla/service/pattern_matcher_gmock.h new file mode 100644 index 0000000000000000000000000000000000000000..8fe2d10a11b5b2d26ee222c63e0db2d55e361d12 --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ + +#include +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +namespace pattern_matcher_gmock_detail { +template +class GmockMatcher { + public: + explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {} + + // In service of better error messages, list out the overloads explicitly + // rather than just using a template. gMock's polymorphism plus + // pattern_matcher yields some pretty gnarly stuff. + bool MatchAndExplain(const Layout& l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&l, listener); + } + bool MatchAndExplain(const Layout* l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(l, listener); + } + + bool MatchAndExplain(const Shape& s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&s, listener); + } + bool MatchAndExplain(const Shape* s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(s, listener); + } + + bool MatchAndExplain(const HloInstruction& instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&instr, listener); + } + bool MatchAndExplain(const HloInstruction* instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(instr, listener); + } + + void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); } + + void DescribeNegationTo(std::ostream* os) const { + *os << "is NOT: "; + DescribeTo(os); + } + + private: + template + bool MatchAndExplainImpl(const T* t, + ::testing::MatchResultListener* listener) const { + MatchOption options{/*.capture=*/true, /*.explain_os=*/listener->stream()}; + return Match(t, 
pattern_, options); + } + + Pattern pattern_; +}; +} // namespace pattern_matcher_gmock_detail + +template +::testing::PolymorphicMatcher< + pattern_matcher_gmock_detail::GmockMatcher> +GmockMatch(Pattern&& p) { + return ::testing::MakePolymorphicMatcher( + pattern_matcher_gmock_detail::GmockMatcher( + std::forward(p))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ca2fb05c1f7ef093c58237cf21fbc7c813a592a --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +namespace m = ::xla::match; +using ::testing::Eq; +using ::testing::Not; + +template +string Describe(const ::testing::Matcher& m) { + std::stringstream ss; + m.DescribeTo(&ss); + return ss.str(); +} + +template +string Explain( + const MatchedTy& val, + const ::testing::Matcher::type>& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(val, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(val, &listener)); + return listener.str(); +} + +// This file tests the GmockMatch function. The actual explanations and +// descriptions returned by matchers are tested in pattern_matcher_test. +TEST(PatternMatcherGmock, MatchShape) { + Shape s = ShapeUtil::MakeShape(F32, {10, 100}); + // You can pass const Shape& or a const Shape*. 
+ EXPECT_THAT(s, GmockMatch(m::Shape())); + EXPECT_THAT(&s, Not(GmockMatch(m::Shape().WithElementType(F16)))); + EXPECT_THAT(Describe(GmockMatch(m::Shape().IsArray())), + "a shape that represents an array"); +} + +TEST(PatternMatcherGmock, MatchLayout) { + Layout l = LayoutUtil::MakeLayout({0, 1}); + EXPECT_THAT(l, GmockMatch(m::Layout())); + EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat()))); + EXPECT_THAT(Describe(GmockMatch(m::Layout().WithSparseFormat())), + "a layout with format SPARSE"); +} + +TEST(PatternMatcherGmock, MatchInstruction) { + auto instr = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {42}), "p"); + EXPECT_THAT(instr.get(), GmockMatch(m::Parameter())); + EXPECT_THAT(*instr, GmockMatch(m::Parameter(0))); + EXPECT_THAT(*instr, Not(GmockMatch(m::Parameter(1)))); + EXPECT_THAT(Describe(GmockMatch(m::Parameter())), + "an HloInstruction with opcode parameter"); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 3f74273517aebfd6f2700a9275b92765e29f21cc..186ef0c7911a2724df810780e018f52586e3e6a8 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -14,14 +14,18 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { +namespace m = match; + TEST(PatternMatcherTest, AddOp) { constexpr char kModuleStr[] = R"(HloModule two_plus_two_module ENTRY %two_plus_two_computation () -> f32[] { @@ -229,23 +233,74 @@ TEST(PatternMatcherTest, AnyOf) { } TEST(PatternMatcherTest, ConstantScalar) { - constexpr char kModuleStr[] = R"( - HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); - auto* root = hlo_module->entry_computation()->root_instruction(); - - EXPECT_TRUE(Match(root, match::ConstantScalar(42))); - EXPECT_FALSE(Match(root, match::ConstantScalar(41))); - EXPECT_FALSE(Match(root, match::ConstantScalar(0))); -} + using match::ConstantEffectiveScalar; + using match::ConstantScalar; + using match::Op; + using match::Tuple; -TEST(PatternMatcherTest, NoMatchConstantScalar) { constexpr char kModuleStr[] = R"( - HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; + HloModule test_module + ENTRY test { + a = s32[] constant(1) + b = s32[1,1] constant(s32[1,1]{{2}}) + c = s32[1,2] constant(s32[1,2]{{2,2}}) + d = f32[] constant(1) + e = f32[] constant(1.25) + ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) + })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); - EXPECT_FALSE(Match(root, match::ConstantScalar(42))); + const HloInstruction* a = root->operand(0); + const HloInstruction* b = root->operand(1); + const 
HloInstruction* c = root->operand(2); + const HloInstruction* d = root->operand(3); + const HloInstruction* e = root->operand(4); + EXPECT_TRUE(Match(a, ConstantScalar())); + EXPECT_TRUE(Match(a, ConstantScalar(1))); + EXPECT_TRUE(Match(a, ConstantEffectiveScalar())); + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1))); + EXPECT_FALSE(Match(a, ConstantScalar(2))); + EXPECT_FALSE(Match(a, ConstantScalar(2.01))); + EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2))); + EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01))); + + EXPECT_FALSE(Match(b, ConstantScalar())); + EXPECT_FALSE(Match(b, ConstantScalar(2))); + EXPECT_TRUE(Match(b, ConstantEffectiveScalar())); + EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2))); + + EXPECT_FALSE(Match(c, ConstantScalar())); + EXPECT_FALSE(Match(c, ConstantScalar(2))); + EXPECT_FALSE(Match(c, ConstantEffectiveScalar())); + EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2))); + + EXPECT_TRUE(Match(d, ConstantScalar(1))); + EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1))); + EXPECT_TRUE(Match(d, ConstantScalar(1.0))); + EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0))); + + EXPECT_TRUE(Match(e, ConstantScalar(1.25f))); + EXPECT_TRUE(Match(e, ConstantScalar(1.25))); + EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25))); + EXPECT_FALSE(Match(e, ConstantScalar(1))); + EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1))); + + const HloInstruction* instr = nullptr; + EXPECT_TRUE(Match(a, ConstantScalar(&instr))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1))); + EXPECT_EQ(instr, a); } TEST(PatternMatcherTest, MultiplyAnyOrder) { @@ -267,6 +322,15 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) { root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); EXPECT_TRUE(Match( 
root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); + + // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call + // e.g. IsNonConstant() on it. + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)) + .IsNonConstant())); + EXPECT_TRUE( + Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52)) + .IsNonConstant())); } TEST(PatternMatcherTest, AnyOfShortCircuit) { @@ -315,14 +379,22 @@ TEST(PatternMatcherTest, AllOf) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); + auto f16_scalar = ShapeUtil::MakeShape(F16, {}); + auto f16_pattern = Constant().WithShapeEqualTo(&f16_scalar); + auto f16_compatible_pattern = Constant().WithShapeCompatibleTo(&f16_scalar); auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); - auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); ASSERT_TRUE(Match(root, scalar_pattern)); ASSERT_TRUE(Match(root, f16_pattern)); - EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern))); - EXPECT_TRUE(Match(root, AllOf(f16_pattern, scalar_pattern))); + ASSERT_TRUE(Match(root, f16_compatible_pattern)); + EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern, + f16_compatible_pattern))); + EXPECT_TRUE( + Match(root, AllOf(f16_pattern, f16_compatible_pattern, + scalar_pattern))); EXPECT_FALSE( Match(root, AllOf(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE(Match( + root, AllOf(Broadcast(Op()), f16_compatible_pattern))); EXPECT_FALSE( Match(root, AllOf(Broadcast(Op()), scalar_pattern))); } @@ -431,5 +503,433 @@ TEST(PatternMatcherTest, TestConcat) { Reshape(ConstantScalar(4))))); } +template +string Description(const Pattern& pattern) { + std::stringstream ss; + pattern.DescribeTo(&ss); + return ss.str(); +} + +template +string Explanation(Elem* elem, const Pattern& pattern) { + std::stringstream ss; + MatchOption 
options{/*.capture=*/true, /*.explain_os=*/&ss}; + Match(elem, pattern, options); + return ss.str(); +} +template +string Explanation(const std::unique_ptr& elem, const Pattern& pattern) { + return Explanation(elem.get(), pattern); +} +template +string Explanation(const Elem& elem, const Pattern& pattern) { + return Explanation(&elem, pattern); +} + +// Helper macro for checking a pattern's description and the explanation printed +// when attempting to match (and presumably failing) on a given object. +// +// We use a macro rather than a function because we want good line numbers in +// errors. We use this rather than writing a helper that returns a pair of +// (description, explanation) and doing something like +// +// EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...)); +// +// because EXPECT_EQ prints a unified diff if multiline string comparison fails, +// while EXPECT_THAT does not. This unified diff makes the errors much easier +// to read. +#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc, \ + expected_explanation) \ + do { \ + EXPECT_EQ(Description(pattern), (expected_desc)); \ + EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \ + } while (0) + +TEST(PatternMatcherTest, LayoutDescribeToAndExplain) { + auto layout = LayoutUtil::MakeLayout({1, 2}); + auto layout2 = LayoutUtil::MakeLayout({2, 2}); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), m::Layout(), + "a layout", "Layout is null"); + EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout), + "a layout equal to {1,2}", + "Layout {2,2} is not equal to expected {1,2}"); + EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(), + "a layout with format SPARSE", + "Layout has format DENSE but expected SPARSE"); + EXPECT_DESC_AND_EXPLANATION(layout, + m::Layout().EqualTo(&layout).WithSparseFormat(), + "a layout:\n" + " * equal to {1,2} AND\n" + " * with format SPARSE", + "Layout has format DENSE but expected SPARSE"); +} + 
+TEST(PatternMatcherTest, ShapeDescribeToAndExplain) { + auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1}); + auto layout = shape.layout(); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), m::Shape(), + "a shape", "Shape is null"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}), + m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}", + "Shape not equal to f32[1,2]{0,1}\n" + "in f32[1,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}), + m::Shape().CompatibleTo(&shape), + "a shape compatible with f32[1,2]", + "Shape not compatible with f32[1,2]\n" + "in f32[2,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16), + "a shape with element type F16", + "Shape does not have element type F16\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(), + "a shape that represents a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(), + "a shape that represents an array", + "Shape is not an array\n" + "in ()"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(), + "a shape that represents a tuple", + "Shape is not a tuple\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(), + "a shape that is an effective scalar", + "Shape is not an effective scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42), + "a shape that has 42 dimensions", + "Shape does not have rank 42\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0), + "a shape that is a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(), + "a shape:\n" + " * that has 1 dimension AND\n" + " * that represents an array", + "Shape does not have rank 1\n" + "in f32[1,2]{0,1}"); + 
EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), + m::Shape().IsArray().WithRank(1), + "a shape:\n" + " * that represents an array AND\n" + " * that has 1 dimension", + "Shape is not an array\n" + "in ()"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}), + m::Shape().WithLayoutEqualTo(&layout), + "a shape with\n a layout equal to {0,1}", + "Layout {1,0} is not equal to expected {0,1}\n" + "in f32[1,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION( + shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()), + "a shape with\n a layout with format SPARSE", + "Layout has format DENSE but expected SPARSE\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, + m::Shape().WithSubshapeEqualTo({10}, &shape), + "a shape with subshape at index {10} which is\n" + " a shape equal to f32[1,2]{0,1}", + "No subshape at {10}\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}), + m::Shape().WithSubshapeEqualTo({0}, &shape), + "a shape with subshape at index {0} which is\n" + " a shape equal to f32[1,2]{0,1}", + "Shape not equal to f32[1,2]{0,1}\n" + "in f32[2,2]{1,0}\n" + "in subshape at {0}\n" + "in (f32[2,2])"); + EXPECT_DESC_AND_EXPLANATION(shape, + m::Shape().WithSubshapeCompatibleTo({10}, &shape), + "a shape with subshape at index {10} which is\n" + " a shape compatible with f32[1,2]", + "No subshape at {10}\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}), + m::Shape().WithSubshapeCompatibleTo({0}, &shape), + "a shape with subshape at index {0} which is\n" + " a shape compatible with f32[1,2]", + "Shape not compatible with f32[1,2]\n" + "in f32[2,2]{1,0}\n" + "in subshape at {0}\n" + "in (f32[2,2])"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}), + m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()), + "a shape with subshape at index {0,0} 
which is\n" + " a shape that represents a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}\n" + "in subshape at {0,0}\n" + "in ((f32[1,2]))"); +} + +std::unique_ptr SetName(absl::string_view name, + std::unique_ptr instr) { + instr->SetAndSanitizeName(string(name)); + return instr; +} + +TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { + std::unique_ptr iota = + SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), + /*iota_dimension=*/0)); + std::unique_ptr constant = + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), + m::Op(), "an HloInstruction", + "HloInstruction* is null"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"), + "an HloInstruction named \"foo\"", + "HloInstruction not named \"foo\"\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd), + "an HloInstruction with opcode add", + "HloInstruction doesn't have opcode add\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + constant, m::Op().IsNonConstant(), + "an HloInstruction with any opcode other than constant", + "HloInstruction has opcode constant, expected anything else\n" + "in c = s32[] constant(0)"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42), + "an HloInstruction with 42 operands", + "HloInstruction doesn't have 42 operands\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()), + "an HloInstruction outputting\n" + " a shape that represents a tuple", + "Shape is not a tuple\n" + "in s32[42]{0}\n" + "in output shape\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)), + "an HloInstruction with operand 2 which is:\n" + " an HloInstruction with opcode add", + "desired operand 
index 2 is out of bounds\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + + EXPECT_DESC_AND_EXPLANATION( + SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}), + HloOpcode::kAdd, constant.get(), + constant.get())), + m::Op().WithOperand(1, m::Op().IsNonConstant()), + "an HloInstruction with operand 1 which is:\n" + " an HloInstruction with any opcode other than constant", + "HloInstruction has opcode constant, expected anything else\n" + "in c = s32[] constant(0)\n" + "in operand 1\n" + "in a = s32[] add(s32[] c, s32[] c)"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop), + "an HloInstruction with fusion kind kLoop", + "HloInstruction does not have fusion kind kLoop; it's not a fusion\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithTupleIndex(42), + "an HloInstruction which is a GTE with index 42", + "HloInstruction is not a GTE with index 42; it's not a GTE at all\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(), + "an HloInstruction which is a constant scalar", + "HloInstruction is not a constant\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2}))), + m::Op().IsConstantEffectiveScalar(), + "an HloInstruction which is a constant effective scalar", + "HloInstruction is not an effective scalar\n" + "in c = s32[2]{0} constant({1, 2})"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))), + m::Op().IsConstantScalar(42), + "an HloInstruction which is a constant scalar with value 42", + "HloInstruction's constant value 10 did not match expected value 42\n" + "in c = s32[] constant(10)"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))), + 
m::Op().IsConstantEffectiveScalar(1.25), + "an HloInstruction which is a constant effective scalar with value 1.25", + "HloInstruction's constant value 2.25 did not match expected value 1.25\n" + "in c = f64[] constant(2.25)"); + EXPECT_DESC_AND_EXPLANATION( + constant, m::Op().Is(iota.get()), + absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)"), + absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x", + absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)\n" + "in c = s32[] constant(0)")); +} + +TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + EXPECT_DESC_AND_EXPLANATION( + SetName("a", HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + SetName("b", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get(), + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get())), + m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")), + "an HloInstruction:\n" + " * with opcode add AND\n" + " * with two operands in either order:\n" + " - an HloInstruction named \"b\"\n" + " - an HloInstruction named \"bar\"", + "HloInstruction's operands (ignoring order) did not match second " + "matcher. 
Specifically,\n" + " - an HloInstruction named \"bar\"\n" + "does not match LHS:\n" + " - HloInstruction not named \"bar\"\n" + " in b = s32[] constant(0)\n" + "does not match RHS:\n" + " - HloInstruction not named \"bar\"\n" + " in c = s32[] constant(0)\n" + "in a = s32[] add(s32[] b, s32[] c)"); + + EXPECT_DESC_AND_EXPLANATION( + SetName("a", + HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + HloInstruction::CreateParameter(0, scalar_s32, "p").get(), + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get())), + m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()), + "an HloInstruction:\n" + " * with opcode add AND\n" + " * with two operands in either order:\n" + " - an HloInstruction which is a constant scalar\n" + " - an HloInstruction with opcode constant", + "HloInstruction's LHS operand did not match either of the two matchers. " + "Specifically,\n" + " - an HloInstruction which is a constant scalar\n" + "does not match LHS:\n" + " - HloInstruction is not a constant\n" + " in p = s32[] parameter(0)\n" + "and\n" + " - an HloInstruction with opcode constant\n" + "does not match LHS:\n" + " - HloInstruction doesn't have opcode constant\n" + " in p = s32[] parameter(0)\n" + "in a = s32[] add(s32[] p, s32[] c)"); +} + +TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), + m::AnyOf(m::Op().WithName("foo"), + m::Op().WithName("bar")), + "any of:\n" + " - an HloInstruction named \"foo\" OR\n" + " - an HloInstruction named \"bar\"", + "None of the following matchers succeeded:\n" + "Matcher #1\n" + " - an HloInstruction named \"foo\"\n" + "failed with\n" + " - HloInstruction not named \"foo\"\n" + " in c = s32[] constant(0)\n" + "Matcher #2\n" + " - an HloInstruction named \"bar\"\n" + "failed with\n" + " - HloInstruction not named \"bar\"\n" + " in c = s32[] constant(0)"); +} + 
+TEST(PatternMatcherTest, Parameter) { + auto param = + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + auto non_param = + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + EXPECT_FALSE(Match(param.get(), m::Parameter(0))); + EXPECT_TRUE(Match(param.get(), m::Parameter())); + EXPECT_TRUE(Match(param.get(), m::Parameter(1))); + EXPECT_FALSE(Match(non_param.get(), m::Parameter())); + EXPECT_FALSE(Match(non_param.get(), m::Parameter(1))); + + EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1), + "an HloInstruction:\n" + " * with opcode parameter AND\n" + " * which is parameter 1", + "HloInstruction doesn't have opcode parameter\n" + "in c = s32[] constant(0)"); + EXPECT_EQ(Explanation(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "p0"), + m::Parameter(1)), + "HloInstruction is not parameter 1\n" + "in p0 = f32[] parameter(0)"); +} + +TEST(PatternMatcherTest, OneUseAndOneUser) { + auto param = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUse(), + "an HloInstruction which has exactly one use", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUser(), + "an HloInstruction which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + { + auto reshape = + SetName("r", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + + auto reshape1 = + SetName("r1", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); 
+ EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + + const char* kMultipleUserExplanation = + "HloInstruction has 2 users, but expected exactly one.\n" + "All users:\n" + " - r = f32[1]{0} reshape(f32[] p0)\n" + " - r1 = f32[1]{0} reshape(f32[] p0)\n" + "in p0 = f32[] parameter(0)"; + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + kMultipleUserExplanation); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()), + kMultipleUserExplanation); + } + + auto add = SetName("add", HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, + param.get(), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + "HloInstruction is used 2 times by its user, but is expected to be " + "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "in p0 = f32[] parameter(0)"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index 16fa80d53e7dc3456b0dade8b92cf101b3e0a33d..efeec96571455d8a9e4b7837dd7286392c12f1a3 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -54,7 +54,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -81,7 +81,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -111,7 +111,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -140,7 +140,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -173,7 +173,7 @@ TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { HloInstruction* d = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -205,7 +205,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -242,7 +242,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -295,7 +295,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -321,7 +321,7 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, a, 8, 23)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -348,7 +348,7 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 5, 10)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -376,7 +376,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 8, 23)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. 
@@ -402,7 +402,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -438,7 +438,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -485,7 +485,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewUnverifiedModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 75f7413b3c303da620c2815c83e03324148c0961..5ec7fe2adedac2fc3d8a7588e853dba90e99006f 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -275,8 +276,8 @@ StatusOr> Service::CreateModuleConfig( } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { - const auto& shape_with_output_layout = - execution_options->shape_with_output_layout(); + const Shape shape_with_output_layout( + execution_options->shape_with_output_layout()); TF_RETURN_IF_ERROR( ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( @@ -658,9 +659,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(request.computation().host_program_shape(), - replicated_arguments.front(), - request.execution_options())); + CreateModuleConfig( + ProgramShape{request.computation().host_program_shape()}, + replicated_arguments.front(), request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -745,9 +746,9 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%d) exceeds the number of available devices " - "on the target (%d)", - arg->device_count(), available_device_count); + "Requested logical device count (%d) with replica count (%d) exceeds " + "the number of available physical devices on the target (%d)", + arg->device_count(), replica_count, available_device_count); } for (int64 i = 0; i < arg->device_count(); ++i) { @@ -818,14 +819,17 @@ 
Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { "The compile request does not support multiple device handles."); } - std::vector argument_shapes; - absl::c_transform(arg->input_shape_with_layout(), - std::back_inserter(argument_shapes), - [](const Shape& shape) { return &shape; }); + std::vector argument_shapes; + argument_shapes.reserve(arg->input_shape_with_layout_size()); + std::vector argument_shape_ptrs; + for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) { + argument_shapes.push_back(Shape(shape_proto)); + argument_shape_ptrs.push_back(&argument_shapes.back()); + } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(arg->computation().host_program_shape(), - argument_shapes, &arg->execution_options())); + CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, + argument_shape_ptrs, &arg->execution_options())); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -930,14 +934,14 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); - const Shape* return_shape; + Shape return_shape; if (arg->has_shape_with_layout()) { - if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { + return_shape = Shape(arg->shape_with_layout()); + if (!LayoutUtil::HasLayout(return_shape)) { return InvalidArgument("shape_with_layout must have layout if present."); } - return_shape = &arg->shape_with_layout(); } else { - return_shape = &shaped_buffer->on_host_shape(); + return_shape = Shape(shaped_buffer->on_host_shape()); } TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( @@ -948,30 +952,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if 
(LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) { *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal.Relayout(*return_shape).ToProto(); + result_literal.Relayout(return_shape).ToProto(); } return Status::OK(); } -namespace { - -// Creates a clone of the given shaped buffer with the given device ordinal. The -// shape and DeviceMemoryBase values of the clone are identical to the original. -std::unique_ptr CloneShapedBufferOnDevice( - const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = absl::make_unique( - shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), - shaped_buffer.platform(), device_ordinal); - clone->buffers() = shaped_buffer.buffers(); - return clone; -} - -} // namespace - Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(Literal literal, @@ -1060,11 +1049,11 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout())); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), literal)); + executor, Shape(arg->shape_with_layout()), literal)); *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1087,7 +1076,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, "constant computation may not depend on any parameters."); } - ProgramShape program_shape = arg->computation().host_program_shape(); + ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if (arg->has_output_layout()) { 
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( @@ -1118,7 +1107,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); - *result->mutable_shape() = buffer->on_host_shape(); + *result->mutable_shape() = buffer->on_host_shape().ToProto(); return Status::OK(); } @@ -1131,7 +1120,7 @@ Status Service::GetComputationGraphStats( return InvalidArgument("Program shape may not be empty."); } - HloModuleConfig config(arg->computation().host_program_shape()); + HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()}); config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 61a60ef9efa72f53fa2c6730ca297ddfe01c56ba..7e7282a737041458aed39b0054f901c23aa87d7a 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -391,17 +391,6 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } -/* static */ StatusOr ShapeInference::InferAfterAllShape( - absl::Span arg_shapes) { - for (const Shape* arg_shape : arg_shapes) { - if (arg_shape->element_type() != TOKEN) { - return InvalidArgument( - "Operands of token instructions must be TOKEN types."); - } - } - return ShapeUtil::MakeTokenShape(); -} - /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -1029,7 +1018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kTuple: { Shape 
result = ShapeUtil::MakeTupleShape({}); - result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); + result.mutable_tuple_shapes()->reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } @@ -2038,7 +2027,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dimension); } - return ShapeUtil::MakeShape(S64, {}); + // TODO(b/119580730): Remove this restriction when very large dimension size + // is needed. + if (shape.dimensions(dimension) > std::numeric_limits::max()) { + return InvalidArgument( + "GetDimensionSize's input shape is %s, the %dth dimension exceeds the " + "UINT_MAX limit.", + ShapeUtil::HumanString(shape), dimension); + } + + return ShapeUtil::MakeShape(U32, {}); } /* static */ StatusOr ShapeInference::InferSliceShape( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 31ef4b2e41078f87731a1eff58e37409a6004ba4..d94385a04d50baff8156570a09620fd458547936 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -232,13 +232,6 @@ class ShapeInference { static StatusOr InferConcatOpShape( absl::Span arg_shapes, int64 dimension); - // Infers the shape produced by a kAfterAll. Trivially this shape is always a - // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes - // and checking operand shapes. This method verifies that the operand shapes - // are all TOKENs. - static StatusOr InferAfterAllShape( - absl::Span arg_shapes); - // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. 
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 7a565bf076847a4a5f7c98635785c80d86df152d..17cdaa74fc328d156292f5af828d4222a9a01f1f 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -172,7 +172,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - auto module = CreateNewUnverifiedModule("fuse_with_constant_operands"); + auto module = CreateNewVerifiedModule("fuse_with_constant_operands"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( @@ -247,7 +247,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -302,7 +302,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -362,7 +362,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); 
HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -428,7 +428,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewUnverifiedModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 96f3055c98e0611dfe25517cb490014a6d1f7c76..50d51eaeb762e208004c1dae3dcc27503f3f94e9 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -280,6 +280,13 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleAddDependency( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand. + CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. 
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index bcfcb388f95b0bedb35a8c399e804034816867b3..0a1d5649d6d69fea12263e6986ce76af62615ec7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -252,6 +252,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 10ef2d38fa21c3e93c270535bc99b2f76435337d..561762b5d424ed5f537665be9d67a81dc8bdd56e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -264,6 +264,22 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { UnorderedElementsAre(inner_tuple)); } +TEST_F(TuplePointsToAnalysisTest, AddDependency) { + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + auto add_dependency = builder.AddInstruction( + HloInstruction::CreateAddDependency(constant, token)); + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(add_dependency); + EXPECT_EQ(1, points_to_set.size()); + EXPECT_FALSE(points_to_set.IsAmbiguous()); + EXPECT_TRUE(points_to_set.IsDistinct()); + ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), {constant}); +} + TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple 
which contains duplicate elements. auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index b7c28bfac7889b788645360366d1419eb80e64de..41011176ffa91e885bc58364d1fb19617d3518ad 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -207,6 +208,37 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( continue; } + if (!hoist_size_inflating_ops_) { + // Check that hoisting the instruction doesn't cause a significant memory + // blow-up. LICM extends the live-range of the output of the hoisted + // instruction to be the entire while loop, which may be problematic on + // platforms where memory is limited. This can be especially harmful if + // the instruction has a significantly larger output than its input, e.g. + // kIota, kBroadcast or kConstant. 
+ int64 input_size = 0, output_size = 0; + + for (auto* operand : instruction->operands()) { + ShapeUtil::ForEachSubshape( + operand->shape(), + [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + input_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + output_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + + if (output_size > input_size) { + continue; + } + } + auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || unhoisted_invariant_instructions.count(op) || diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 3031899f71e0fd77f20448d9d7489798af01615c..bd6232dc0a988775a0490abbf6125daad8476295 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -34,8 +34,14 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { // Setting `hoist_constants` to false can be help if LICM is run in the mid // level HLO pipeline because hoisting constants out of while loop bodies can // break optimizations like constant folding. - explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) - : hoist_constants_(hoist_constants) {} + // Setting `hoist_size_inflating_ops` to false will forbid hoisting + // instructions where the size of the output(s) is larger than the size of the + // input(s). This is useful on platforms on which it's important to prevent + // blow-ups in memory size. 
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false, + bool hoist_size_inflating_ops = true) + : hoist_constants_(hoist_constants), + hoist_size_inflating_ops_(hoist_size_inflating_ops) {} ~WhileLoopInvariantCodeMotion() override = default; absl::string_view name() const override { @@ -49,6 +55,7 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { HloInstruction* while_instr); bool hoist_constants_; + bool hoist_size_inflating_ops_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 046ccb2d3f29c2141ade5275d043875e3e278582..8e7c4bc8828552e197b41f874c070d496b85a382 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -570,5 +570,59 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { EXPECT_FALSE(simplified_loop); } +const char* const kInflatingTestCase = R"( +HloModule ModuleWithWhile + +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} + +body { + p_body = (f32[]) parameter(0) + iota = f32[1024, 1024] iota(), iota_dimension=0 + add = f32[1024, 1024] add(iota, iota) + constant = f32[] constant(1.0) + reduce = f32[] reduce(f32[1024, 1024] add, f32[] constant), dimensions={0,1}, to_apply=mul + ROOT root = (f32[]) tuple(reduce) +} + +condition { + p_cond = (f32[]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + param = f32[] parameter(0) + while_init = (f32[]) tuple(param) + ROOT while = (f32[]) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsInflatingByDefault) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + 
WhileLoopInvariantCodeMotion(/*hoist_constants=*/true).Run(m.get())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = m->GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + EXPECT_THAT(while_body->instructions(), Not(Contains(op::Iota()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true, + /*hoist_size_inflating_ops=*/false) + .Run(m.get())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 6f924a29d8a3ac60abe98efd2e82ae7343c7de47..d30f67dd8110b88166fe807762fb653190ec00bc 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -19,13 +19,17 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" namespace xla { +namespace m = match; using absl::optional; using hlo_query::ContainsInstrWithOpcode; @@ -302,6 +306,147 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return true; } +// Removes each loop parameter (i.e. member of the while loop tuple) that is a +// constant and is the same in the while loop body and the while loop init. 
+static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
+  HloModule* module = while_op->GetModule();
+  HloComputation* computation = while_op->parent();
+  auto* while_init = while_op->mutable_operand(0);
+  auto* while_body = while_op->while_body();
+  auto* while_cond = while_op->while_condition();
+  auto* while_body_root = while_body->root_instruction();
+  // Only loops whose init and body root are both explicit tuples are handled;
+  // their operands are compared element by element below.
+  if (while_init->opcode() != HloOpcode::kTuple ||
+      while_body_root->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+
+  TF_RET_CHECK(while_cond->num_parameters() == 1);
+  TF_RET_CHECK(while_body->num_parameters() == 1);
+  TF_RET_CHECK(
+      ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
+
+  // Tuple indices whose init operand and body-root operand are both kConstant
+  // with equal literals.
+  absl::flat_hash_set<int64> constant_tuple_indices;
+  const auto& while_shape = while_init->shape();
+  for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+    auto* init_elem = while_init->operand(i);
+    auto* body_elem = while_body_root->operand(i);
+    if (init_elem->opcode() == HloOpcode::kConstant &&
+        body_elem->opcode() == HloOpcode::kConstant &&
+        init_elem->literal() == body_elem->literal()) {
+      constant_tuple_indices.insert(i);
+    }
+  }
+
+  if (constant_tuple_indices.empty()) {
+    return false;
+  }
+
+  // OK, we found some constant elements of the while parameter!  Eliminate
+  // them.
+  std::vector<Shape> new_while_shape_elems;
+  for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+    if (!constant_tuple_indices.count(i)) {
+      new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
+    }
+  }
+  Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
+
+  // `new_instrs` holds instructions created outside of a computation for
+  // cloning.  Elements added here just need to live until the end of the
+  // relevant CloneWithReplacement call.
+  std::vector<std::unique_ptr<HloInstruction>> new_instrs;
+  auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
+    new_instrs.push_back(std::move(instr));
+    return new_instrs.back().get();
+  };
+
+  // Returns a new tuple without the elements of constant_tuple_indices.
+  auto remove_constant_elems = [&](HloInstruction* instr) {
+    CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
+
+    std::vector<HloInstruction*> tuple_elems;
+    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+      if (!constant_tuple_indices.count(i)) {
+        tuple_elems.push_back(
+            add_new_instr(HloInstruction::CreateGetTupleElement(
+                while_shape.tuple_shapes(i), instr, i)));
+      }
+    }
+    return HloInstruction::CreateTuple(tuple_elems);
+  };
+
+  // Inverse of remove_constant_elems: rebuilds a tuple of the original shape
+  // from a tuple of the reduced shape, taking the removed elements from the
+  // operands of `while_init`.  `j` walks the reduced tuple.
+  auto add_constant_elems = [&](HloInstruction* instr) {
+    CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
+
+    std::vector<HloInstruction*> tuple_elems;
+    int64 j = 0;
+    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+      if (constant_tuple_indices.count(i)) {
+        tuple_elems.push_back(while_init->mutable_operand(i));
+      } else {
+        tuple_elems.push_back(
+            add_new_instr(HloInstruction::CreateGetTupleElement(
+                while_shape.tuple_shapes(i), instr, j)));
+        ++j;
+      }
+    }
+    return HloInstruction::CreateTuple(tuple_elems);
+  };
+
+  // Special case: constant_tuple_indices covers the whole while parameter, so
+  // the new while shape is the empty tuple.  In this case, the value of the
+  // while loop is simply equal to the value of `init`.
+  //
+  // It's unfortunate to special-case this, but it's simpler than the
+  // alternative.  The problem is that if our while parameter has no
+  // non-constant elems, the tuple returned by `add_constant_elems` won't depend
+  // on instr (the loop body/cond parameter), and therefore
+  // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
+  // invalid HLO.
+  if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
+    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
+    return true;
+  }
+
+  std::unique_ptr<HloComputation> new_while_cond =
+      while_cond->CloneWithReplacementPairs({
+          while_cond->parameter_instruction(0),
+          add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
+              0, new_while_shape,
+              while_cond->parameter_instruction(0)->name()))),
+      });
+
+  // The body's replacement parameter keeps the *body* parameter's name (the
+  // condition's parameter name was used here before, a copy-paste slip).
+  std::unique_ptr<HloComputation> new_while_body =
+      while_body->CloneWithReplacementPairs(
+          {
+              while_body->parameter_instruction(0),
+              add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
+                  0, new_while_shape,
+                  while_body->parameter_instruction(0)->name()))),
+          },
+          {
+              while_body->root_instruction(),
+              remove_constant_elems(
+                  add_new_instr(while_body->root_instruction()->Clone())),
+          });
+
+  // Create the final while loop, and add any new instructions created to
+  // `computation`.  The instructions collected so far were only needed for the
+  // clone calls above, so they are dropped first; everything collected from
+  // here on is moved into `computation` by the loop below.
+  new_instrs.clear();
+  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+      while_op,
+      add_constant_elems(
+          computation->AddInstruction(HloInstruction::CreateWhile(
+              new_while_shape,
+              module->AddEmbeddedComputation(std::move(new_while_cond)),
+              module->AddEmbeddedComputation(std::move(new_while_body)),
+              add_new_instr(remove_constant_elems(while_init)))))));
+  for (auto& instr : new_instrs) {
+    computation->AddInstruction(std::move(instr));
+  }
+  return true;
+}
+
 // Tries to remove a while loop from the graph. // // - Loops with trip count of 0 can be replaced by the loop's "init" value. @@ -381,16 +526,14 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // performance by forcing us to copy constants. 
absl::flat_hash_map index_to_constant; for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && instr->operand(0) == while_body_param && - ShapeUtil::IsScalar(instr->shape())) { - auto tuple_element = while_init->operand(i); - if (tuple_element->IsConstant()) { - VLOG(3) << "Found loop invariant tuple element " << i << " " - << tuple_element->ToString(); - index_to_constant[i] = tuple_element; - } + const HloInstruction* init_tuple_elem = nullptr; + if (Match(root_operands[i], + m::GetTupleElement(m::Op().Is(while_body_param), i) + .WithShape(m::Shape().IsScalar())) && + Match(while_init->operand(i), m::Constant(&init_tuple_elem))) { + VLOG(3) << "Found loop invariant tuple element " << i << " " + << init_tuple_elem->ToString(); + index_to_constant[i] = init_tuple_elem; } } @@ -519,14 +662,6 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { return false; } - // Cowardly refuse to perform this optimization in the presence of kDomain - // instructions, which may reference other instructions in the loop and - // therefore make this complicated. - if (ContainsInstrWithOpcode(while_body, {HloOpcode::kDomain}) || - ContainsInstrWithOpcode(while_cond, {HloOpcode::kDomain})) { - return false; - } - std::vector flattened_shape_elems; ShapeUtil::ForEachSubshape(while_shape, [&](const Shape& s, const ShapeIndex& /*index*/) { @@ -605,6 +740,243 @@ static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { return true; }
+
+// Tries to merge loop induction variables of a given type.
+//
+// In this pass we're only concerned with elements of the loop's tuple that
+// are effective-scalars of type `elem_ty`.  Some terminology:
+//
+//  - The trip counter is the first element of the loop's tuple that starts at
+//    0 and does x++ on each iteration.
+//
+//  - An induction variable is an element of the loop's tuple that is not the
+//    trip counter and does `x += <constant>` on each iteration of the loop.
+//    Negative constants are OK.
+//
+// This pass adds a trip counter if one isn't already present, then replaces
+// each induction variable with
+//
+//   <initial value> + <trip count> * <constant>.
+//
+// This reduces the number of scalar operations in the loop, which is important
+// e.g. on GPUs, where each scalar operation is nontrivially expensive because
+// it's a separate kernel launch.
+//
+// Returns the new loop if a change was made, or null if no change was made.
+// Note that the new loop is not a valid replacement for the old loop; it may
+// need to be wrapped in a tuple that changes its shape.  We return the loop
+// itself so that you can call TryMergeInductionVariables in a loop, once for
+// each integral type elem_ty.
+static StatusOr<HloInstruction*> TryMergeInductionVariables(
+    HloInstruction* while_op, PrimitiveType elem_ty) {
+  CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
+  HloModule* module = while_op->GetModule();
+  HloComputation* computation = while_op->parent();
+  auto* while_init = while_op->mutable_operand(0);
+  auto* while_body = while_op->while_body();
+  auto* while_cond = while_op->while_condition();
+  auto* while_body_root = while_body->root_instruction();
+  if (while_init->opcode() != HloOpcode::kTuple ||
+      while_body_root->opcode() != HloOpcode::kTuple) {
+    return nullptr;
+  }
+
+  TF_RET_CHECK(while_cond->num_parameters() == 1);
+  TF_RET_CHECK(while_body->num_parameters() == 1);
+  TF_RET_CHECK(
+      ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
+  Shape while_shape = while_init->shape();
+
+  // The tuple index of the trip counter, if one is present.
+  absl::optional<int64> trip_counter;
+  // Maps the tuple index of each induction variable to its constant increment.
+  absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
+  for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
+    // Set by the matcher below; only read after a successful match.
+    HloInstruction* constant = nullptr;
+    if (!Match(while_body_root->mutable_operand(i),
+               m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
+                              m::ConstantScalar(&constant))
+                   .WithShape(m::Shape().WithElementType(elem_ty)))) {
+      continue;
+    }
+    if (!trip_counter && constant->literal().IsAll(1) &&
+        while_init->operand(i)->IsConstant() &&
+        while_init->operand(i)->literal().IsAll(0)) {
+      VLOG(10) << "Found existing trip counter at index " << i;
+      trip_counter = i;
+    } else {
+      VLOG(10) << "Found induction variable at index " << i;
+      induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
+    }
+  }
+
+  // There's only something to simplify if we can either:
+  //
+  //  - combine one or more induction vars with an existing trip counter, or
+  //  - replace two or more induction variables with a new trip counter.
+  //
+  // Put another way, there's only something to simplify if the number of
+  // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
+  if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
+    return nullptr;
+  }
+
+  // OK, we're going to do the transformation!  Set up some helpers.
+
+  // `new_instrs` holds instructions created outside of a computation for
+  // cloning.  Elements added here just need to live until the end of the
+  // relevant CloneWithReplacement call.
+  std::vector<std::unique_ptr<HloInstruction>> new_instrs;
+  auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
+    new_instrs.push_back(std::move(instr));
+    return new_instrs.back().get();
+  };
+
+  auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
+                           HloInstruction* lhs, HloInstruction* rhs) {
+    // Reshape lhs/rhs to the output shape if necessary.  This deals with the
+    // fact that induction variables need only be effective scalars, not true
+    // scalars.
+    if (!ShapeUtil::Compatible(shape, lhs->shape())) {
+      lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
+    }
+    if (!ShapeUtil::Compatible(shape, rhs->shape())) {
+      rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
+    }
+    return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
+  };
+
+  auto add_gte = [&](HloInstruction* src, int64 idx) {
+    return add_new_instr(HloInstruction::CreateGetTupleElement(
+        src->shape().tuple_shapes(idx), src, idx));
+  };
+
+  // Our new while loop will have the same shape as the old while loop, except
+  // we'll add a trip counter to the end if it wasn't originally present.
+  Shape new_while_shape = while_shape;
+  bool added_trip_counter = false;
+  if (!trip_counter) {
+    VLOG(10) << "Adding new trip counter to end of loop's tuple.";
+    trip_counter = new_while_shape.tuple_shapes_size();
+    *new_while_shape.add_tuple_shapes() =
+        ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
+    added_trip_counter = true;
+  }
+
+  // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
+  // shape `while_body->shape()` and where the induction variables are "reified"
+  // (i.e. they have value <init> + <trip counter> * <increment>).
+  auto convert_to_old_form = [&](HloInstruction* instr) {
+    CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
+    std::vector<HloInstruction*> tuple_elems;
+    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+      const auto& elem_shape = while_shape.tuple_shapes(i);
+      if (!induction_vars.count(i)) {
+        tuple_elems.push_back(add_gte(instr, i));
+        continue;
+      }
+      tuple_elems.push_back(add_binary_op(
+          elem_shape, HloOpcode::kAdd, add_gte(instr, i),
+          add_binary_op(elem_shape, HloOpcode::kMultiply,
+                        add_gte(instr, *trip_counter),
+                        add_new_instr(induction_vars.at(i)->Clone()))));
+    }
+    return HloInstruction::CreateTuple(tuple_elems);
+  };
+
+  // Converts `old_root` into a tuple of the "new" form -- that is, to a tuple
+  // with shape `new_while_shape` and where the induction variables (but not
+  // trip counters) are replaced with their unchanging values.
+  auto convert_to_new_form = [&](HloInstruction* old_root,
+                                 HloParameterInstruction* loop_body_param) {
+    CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
+    std::vector<HloInstruction*> tuple_elems;
+
+    // In the new form, induction variables come from `init`, everything else
+    // (including the trip counter if it's not one we created ourselves) comes
+    // from the `root` tuple unmodified.
+    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+      tuple_elems.push_back(
+          add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
+    }
+    // If we created a trip counter ourselves, add 1 to it in the next
+    // iteration.
+    if (added_trip_counter) {
+      tuple_elems.push_back(add_binary_op(
+          new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
+          add_gte(loop_body_param, *trip_counter),
+          add_new_instr(
+              HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
+    }
+
+    return HloInstruction::CreateTuple(tuple_elems);
+  };
+
+  // Creates a new init tuple, which is the same as the old init tuple except if
+  // we added a trip counter, it's set to 0.
+  auto get_new_while_init = [&](HloInstruction* init) {
+    CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
+    if (!added_trip_counter) {
+      return init;
+    }
+    std::vector<HloInstruction*> tuple_elems;
+    for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
+      tuple_elems.push_back(add_gte(init, i));
+    }
+    tuple_elems.push_back(add_new_instr(
+        HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
+    return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
+  };
+
+  std::unique_ptr<HloComputation> new_while_cond =
+      while_cond->CloneWithReplacementPairs({
+          while_cond->parameter_instruction(0),
+          convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
+              0, new_while_shape,
+              while_cond->parameter_instruction(0)->name()))),
+      });
+
+  // Creating the new while body proceeds in two steps.  First we convert the
+  // users of the parameter to the old form.  Then as a second
+  // CloneWithReplacement operation we convert the root to the new form.  We
+  // have to do this in two steps because the new root needs to use the new
+  // param0, and during the first clone operation, only the *old-form* param0 is
+  // accessible.
+  //
+  // We have to add temp_new_while_body to the module because cloning a
+  // computation touches the module (to get its NameUniquer).
+  HloComputation* temp_new_while_body =
+      module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
+          while_body->parameter_instruction(0),
+          convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
+              0, new_while_shape,
+              while_body->parameter_instruction(0)->name()))),
+      }));
+  std::unique_ptr<HloComputation> new_while_body =
+      temp_new_while_body->CloneWithReplacementPairs({
+          temp_new_while_body->root_instruction(),
+          convert_to_new_form(
+              add_new_instr(temp_new_while_body->root_instruction()->Clone()),
+              Cast<HloParameterInstruction>(
+                  temp_new_while_body->parameter_instruction(0))),
+      });
+  TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));
+
+  // Create the final while loop, and add any new instructions created to
+  // `computation`.
+  new_instrs.clear();
+  auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
+      new_while_shape,
+      module->AddEmbeddedComputation(std::move(new_while_cond)),
+      module->AddEmbeddedComputation(std::move(new_while_body)),
+      get_new_while_init(while_init)));
+  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+      while_op, convert_to_old_form(new_while)));
+  for (auto& instr : new_instrs) {
+    computation->AddInstruction(std::move(instr));
+  }
+  return new_while;
+}
+
 StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -650,19 +1022,50 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } + // TODO(b/119281462): Cowardly refuse to perform any of the following + // optimizations in the presence of kDomain instructions. It seems that + // modifying a while loop's tuple doesn't work when kDomain is present. 
+ if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kDomain})) { + continue; + } + + // Each of the optimizations below modifies the while loop itself if it's + // successful, meaning that `while_op` is no longer valid after one of these + // transformations returns true. + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); changed |= result; if (result) { - // Successfully flattening nested tuples results in us cloning and - // replacing the while loop, meaning that `while_op` is no longer valid. continue; } TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); changed |= result; if (result) { - // Successfully removing dead while params results in us cloning and - // replacing the while loop, meaning that `while_op` is no longer valid. + continue; + } + + TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); + changed |= result; + if (result) { + continue; + } + + bool merged_induction_vars = false; + // Notably missing from this list are S16 and U16. These don't currently + // work because S/U16 literals are not implemented. + for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) { + TF_ASSIGN_OR_RETURN(auto* new_while_op, + TryMergeInductionVariables(while_op, elem_ty)); + if (new_while_op) { + while_op = new_while_op; + changed = true; + merged_induction_vars = true; + } + } + if (merged_induction_vars) { continue; } } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 05005e0b262a50cd40e004deac4c450a2e257308..4950e8269e9cf0723d717bd1734518d104c0c9f2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -17,9 +17,12 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -27,8 +30,17 @@ limitations under the License. namespace xla { namespace { +using ::testing::_; namespace op = xla::testing::opcode_matchers; +// Returns the first kWhile instruction within m's entry computation. +HloInstruction* FindFirstWhile(HloModule* m) { + const auto& instrs = m->entry_computation()->instructions(); + return *absl::c_find_if(instrs, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); +} + class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. @@ -540,11 +552,7 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { // it easy to find. EXPECT_TRUE(HloDCE().Run(m.get()).ok()); - const auto& instrs = m->entry_computation()->instructions(); - HloInstruction* new_while = - *absl::c_find_if(instrs, [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + HloInstruction* new_while = FindFirstWhile(m.get()); Shape flat_tuple = ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") .ValueOrDie(); @@ -563,5 +571,177 @@ TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { .ValueOrDie())); } +// Edge-case: All elements of the loop carry are constants which can be removed, +// leaving us with a nullary loop. This is a special case, we just replace the +// loop with its init. 
+TEST_F(WhileLoopSimplifierTest, OnlyConstantsInLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1]) parameter(0) + a = s32[1] constant({0}) + ROOT tuple = (s32[1]) tuple(a) + } + Cond { + param = (s32[1]) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + init = (s32[1]) tuple(a) + ROOT while = (s32[1]) while(init), condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1], s32[2], s32[3]) parameter(0) + a = s32[1] get-tuple-element(param), index=0 + a.1 = s32[1] add(a, a) + b = s32[2] constant({1,1}) + c = s32[3] constant({10,10,10}) + ROOT tuple = (s32[1], s32[2], s32[3]) tuple(a.1, b, c) + } + Cond { + param = (s32[1], s32[2], s32[3]) parameter(0) + /* Use each tuple element. The verifier will then ensure that if any of + * these get modified, they're replaced with values of the correct shape. */ + a = s32[1] get-tuple-element(param), index=0 + b = s32[2] get-tuple-element(param), index=1 + c = s32[3] get-tuple-element(param), index=2 + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + /* Only `b` should be simplified away. `a` is not a constant within the + * loop, and `c`'s value changes depending on whether we run 0 or 1 + * iterations of the loop. 
*/ + a = s32[1] constant({0}) + b = s32[2] constant({1,1}) + c = s32[3] constant({2,2,2}) + init = (s32[1], s32[2], s32[3]) tuple(a,b,c) + ROOT while = (s32[1], s32[2], s32[3]) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + // Run the tuple simplifier to make the resulting HLO a bit easier to check. + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(_, op::Constant(), _)); +} + +const char* const kSimpleMergeInductionVariablesModule = R"( + HloModule Test + Body { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + + a = TYPE[] get-tuple-element(param), index=0 + one = TYPE[] constant(1) + a1 = TYPE[] add(a, one) + + b = TYPE[] get-tuple-element(param), index=1 + negone = TYPE[] constant(-1) + b1 = TYPE[] add(b, negone) + + c = TYPE[] add(a, b) + + ROOT tuple = (TYPE[], TYPE[], TYPE[]) tuple(a1,b1,c) + } + Cond { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + a = TYPE[] get-tuple-element(param), index=0 + b = 
TYPE[] get-tuple-element(param), index=1 + sum = TYPE[] power(a, b) + ten = TYPE[] constant(10) + ROOT cond = pred[] less-than(sum, ten) + } + ENTRY Loop { + a = TYPE[] constant(10) + b = TYPE[] constant(100) + c = TYPE[] constant(0) + init = (TYPE[], TYPE[], TYPE[]) tuple(a,b,c) + while = (TYPE[], TYPE[], TYPE[]) while(init), condition=Cond, body=Body + + a1 = TYPE[] get-tuple-element(while), index=0 + b1 = TYPE[] get-tuple-element(while), index=1 + ROOT sum = TYPE[] add(a1, b1) + })"; + +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s32"}}); + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find, and run the tuple simplifier to make the resulting HLO + // easier to check. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + // We should have added a new loop counter for s32[] to the end of the tuple. 
+ SCOPED_TRACE(m->ToString()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + + EXPECT_THAT(new_while->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1), op::Add(), + op::Add(op::GetTupleElement(op::Parameter(), 3), + op::Constant()))); + EXPECT_THAT(new_while->while_condition()->root_instruction(), + op::Lt(op::Power(op::Add(), op::Add()), op::Constant())); +} + +// We shouldn't merge S16 induction variables; we can't create constants of this +// type because S16 literals are not implemented. +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s16"}}); + EXPECT_FALSE( + WhileLoopSimplifier() + .Run(ParseAndReturnVerifiedModule(hlo_string).ValueOrDie().get()) + .ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..746ab9e9977b1b10cdb0cb57197027d65bd50f55 --- /dev/null +++ b/tensorflow/compiler/xla/shape.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +Shape::Shape(const ShapeProto& shape_proto) { + set_element_type(shape_proto.element_type()); + dimensions_.reserve(shape_proto.dimensions_size()); + for (const int64 dimension : shape_proto.dimensions()) { + add_dimensions(dimension); + } + tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); + for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { + *add_tuple_shapes() = Shape(element_shape); + } + if (shape_proto.has_layout()) { + *mutable_layout() = shape_proto.layout(); + } +} + +ShapeProto Shape::ToProto() const { + ShapeProto proto; + proto.set_element_type(element_type_); + proto.mutable_dimensions()->Reserve(dimensions_size()); + for (const int64 dimension : dimensions()) { + proto.add_dimensions(dimension); + } + proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); + for (const Shape& shape : tuple_shapes()) { + *proto.add_tuple_shapes() = shape.ToProto(); + } + if (has_layout()) { + *proto.mutable_layout() = layout(); + } + return proto; +} + +string Shape::ToString(bool print_layout) const { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(*this); + } else { + return ShapeUtil::HumanString(*this); + } +} + +std::ostream& operator<<(std::ostream& out, const Shape& shape) { + out << shape.ToString(/*print_layout=*/true); + return out; +} + 
+ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) { + for (const ShapeProto& shape_proto : program_shape_proto.parameters()) { + *add_parameters() = Shape(shape_proto); + } + *mutable_result() = Shape(program_shape_proto.result()); + for (const string& name : program_shape_proto.parameter_names()) { + add_parameter_names(name); + } +} + +ProgramShapeProto ProgramShape::ToProto() const { + ProgramShapeProto proto; + for (const Shape& shape : parameters()) { + *proto.add_parameters() = shape.ToProto(); + } + *proto.mutable_result() = result().ToProto(); + for (const string& name : parameter_names()) { + proto.add_parameter_names(name); + } + return proto; +} + +string ProgramShape::ToString() const { + std::vector<string> parameter_strings(parameters_size()); + for (int i = 0; i < parameters_size(); ++i) { + parameter_strings[i] = absl::StrCat( + i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ", + ShapeUtil::HumanString(parameters(i))); + } + return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ", + ShapeUtil::HumanString(result())); +} + +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) { + out << program_shape.ToString() << "\n"; + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..7f6b14ab4286c696dce64d2250a3fe8a57e4865b --- /dev/null +++ b/tensorflow/compiler/xla/shape.h @@ -0,0 +1,204 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ +#define TENSORFLOW_COMPILER_XLA_SHAPE_H_ + +#include <string> +#include <vector> + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A shape describes the number of dimensions in an array, the bounds of each +// dimension, and the primitive component type. For tuples, shape describes the +// structure (number of elements and nesting). +class Shape { + public: + Shape() = default; + + // Construct a shape from a ShapeProto. + explicit Shape(const ShapeProto& shape_proto); + + // Returns a ShapeProto representation of the Shape. + ShapeProto ToProto() const; + + // Returns a human-readable string that represents the given shape, with or + // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". + string ToString(bool print_layout = false) const; + + // The following methods mirror the protobuf generated code interface for the + // message ShapeProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the primitive type. + PrimitiveType element_type() const { return element_type_; } + void set_element_type(PrimitiveType value) { element_type_ = value; } + + // Methods for accessing the dimensions array. 
+ int dimensions_size() const { return dimensions_.size(); } + int64 dimensions(int index) const { return dimensions_.at(index); } + void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } + void add_dimensions(int64 value) { dimensions_.push_back(value); } + void clear_dimensions() { dimensions_.clear(); } + const std::vector<int64>& dimensions() const { return dimensions_; } + std::vector<int64>* mutable_dimensions() { return &dimensions_; } + + // Methods for accessing the tuple subshapes. This field is only non-empty + // for tuple shapes. + int tuple_shapes_size() const { return tuple_shapes_.size(); } + const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } + Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } + Shape* add_tuple_shapes() { + tuple_shapes_.push_back(Shape()); + return &tuple_shapes_.back(); + } + void clear_tuple_shapes() { tuple_shapes_.clear(); } + const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; } + std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; } + + // Methods for accessing the layout field. + bool has_layout() const { return layout_.has_value(); } + const Layout& layout() const { + if (layout_.has_value()) { + return *layout_; + } else { + return Layout::default_instance(); + } + } + Layout* mutable_layout() { + if (!layout_.has_value()) { + layout_ = Layout(); + } + return &layout_.value(); + } + void clear_layout() { layout_.reset(); } + + void Swap(Shape* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + element_type_ = PRIMITIVE_TYPE_INVALID; + dimensions_.clear(); + tuple_shapes_.clear(); + layout_.reset(); + } + + string SerializeAsString() const { return ToProto().SerializeAsString(); } + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + public: + // The element type of this shape (tuple, array, etc). 
+ PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; + + // The array bounds of the dimensions. This is nonempty only for array shapes. + std::vector<int64> dimensions_; + + // The tuple element subshapes. This is nonempty only for tuple shapes. + std::vector<Shape> tuple_shapes_; + + // The array layout of the shape. This is present only for array shapes. + absl::optional<Layout> layout_; +}; + +// Shape of the parameters and output of an XLA computation. This is analogous +// to a traditional function signature. +class ProgramShape { + public: + ProgramShape() = default; + + // Creates a ProgramShape from a ProgramShapeProto protobuf. + explicit ProgramShape(const ProgramShapeProto& program_shape_proto); + + // Returns a proto representation of the object. + ProgramShapeProto ToProto() const; + + string ToString() const; + + // The following methods mirror the protobuf generated code interface for the + // message ProgramShapeProto. This enabled easy migration of this data + // structure from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing and manipulating the Shape of the parameters. + int parameters_size() const { return parameters_.size(); } + const Shape& parameters(int index) const { return parameters_.at(index); } + Shape* mutable_parameters(int index) { return &parameters_.at(index); } + Shape* add_parameters() { + parameters_.emplace_back(); + return &parameters_.back(); + } + void clear_parameters() { parameters_.clear(); } + const std::vector<Shape>& parameters() const { return parameters_; } + std::vector<Shape>* mutable_parameters() { return &parameters_; } + + // Methods for accessing and manipulating the Shape of the result. + const Shape& result() const { return result_; } + Shape* mutable_result() { return &result_; } + + // Methods for accessing and manipulating the names of the parameters. 
+ int parameter_names_size() const { return parameter_names_.size(); } + const string& parameter_names(int index) const { + return parameter_names_.at(index); + } + void set_parameter_names(int index, const string& value) { + parameter_names_.at(index) = value; + } + string* mutable_parameter_names(int index) { + return &parameter_names_.at(index); + } + void add_parameter_names(const string& value) { + parameter_names_.push_back(value); + } + string* add_parameter_names() { + parameter_names_.push_back(""); + return &parameter_names_.back(); + } + void clear_parameter_names() { parameter_names_.clear(); } + const std::vector<string>& parameter_names() const { + return parameter_names_; + } + std::vector<string>* mutable_parameter_names() { return &parameter_names_; } + + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + private: + // The shapes of the parameters of the computation represented by this object. + std::vector<Shape> parameters_; + + // The names of the parameters of the computation represented by this object. + std::vector<string> parameter_names_; + + // The shape of the result of the computation represented by this object. + Shape result_; +}; + +std::ostream& operator<<(std::ostream& out, const Shape& shape); +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e396897eeebc2e7bdc2dc49300c8906710608b05 --- /dev/null +++ b/tensorflow/compiler/xla/shape_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ShapeTest : public ::testing::Test { + protected: + const Shape opaque_ = ShapeUtil::MakeOpaqueShape(); + const Shape token_ = ShapeUtil::MakeTokenShape(); + const Shape scalar_ = ShapeUtil::MakeShape(F32, {}); + const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2}); + const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); + const Shape tuple_ = + ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); + const Shape nested_tuple_ = + ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); +}; + +TEST_F(ShapeTest, ShapeToFromProto) { + for (const Shape& shape : + {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) { + Shape shape_copy(shape.ToProto()); + EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) + << shape << " != " << shape_copy; + } +} + +TEST_F(ShapeTest, ShapeToString) { + EXPECT_EQ("opaque[]", opaque_.ToString()); + EXPECT_EQ("token[]", token_.ToString()); + EXPECT_EQ("f32[]", 
scalar_.ToString()); + EXPECT_EQ("u32[1,2]", matrix_.ToString()); + EXPECT_EQ("s32[3,4]", matrix2_.ToString()); + EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", tuple_.ToString()); + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + nested_tuple_.ToString()); + + EXPECT_EQ("opaque[]", opaque_.ToString(/*print_layout=*/true)); + EXPECT_EQ("f32[]", scalar_.ToString(/*print_layout=*/true)); + EXPECT_EQ("u32[1,2]{1,0}", matrix_.ToString(/*print_layout=*/true)); + EXPECT_EQ("s32[3,4]{0,1}", matrix2_.ToString(/*print_layout=*/true)); + EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", + tuple_.ToString(/*print_layout=*/true)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + nested_tuple_.ToString(/*print_layout=*/true)); +} + +TEST_F(ShapeTest, ProgramShapeToFromProto) { + ProgramShape program_shape; + *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + *program_shape.add_parameters() = ShapeUtil::MakeTokenShape(); + *program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {}); + *program_shape.add_parameters() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeShape(F32, {42, 42})}); + + *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7}); + + program_shape.add_parameter_names("foo"); + program_shape.add_parameter_names("bar"); + program_shape.add_parameter_names("baz"); + program_shape.add_parameter_names("qux qux"); + + // Create a copy of the program shape by round-tripping through a proto. 
+ ProgramShape program_shape_copy(program_shape.ToProto()); + ASSERT_EQ(program_shape.parameters_size(), + program_shape_copy.parameters_size()); + for (int i = 0; i < program_shape.parameters_size(); ++i) { + EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i), + program_shape_copy.parameters(i))); + } + + EXPECT_TRUE( + ShapeUtil::Equal(program_shape.result(), program_shape_copy.result())); + + ASSERT_EQ(program_shape.parameter_names_size(), + program_shape_copy.parameter_names_size()); + for (int i = 0; i < program_shape.parameter_names_size(); ++i) { + EXPECT_EQ(program_shape.parameter_names(i), + program_shape_copy.parameter_names(i)); + } +} + +TEST_F(ShapeTest, ProgramShapeToString) { + ProgramShape prog = ShapeUtil::MakeProgramShape( + {opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}, + nested_tuple_); + EXPECT_EQ( + "((unknown): opaque[], " + "(unknown): f32[], " + "(unknown): u32[1,2], " + "(unknown): s32[3,4], " + "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + prog.ToString()); + + prog.add_parameter_names("arg0"); + prog.add_parameter_names("scalar"); + prog.add_parameter_names("matrix"); + prog.add_parameter_names("matrix2"); + prog.add_parameter_names("tuple"); + prog.add_parameter_names("nested_tuple"); + EXPECT_EQ( + "(arg0: opaque[], " + "scalar: f32[], " + "matrix: u32[1,2], " + "matrix2: s32[3,4], " + "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + prog.ToString()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index df610102b4c7fa08c0b7030124939009130f89f4..7bf97729165bef98fabc29040e02203eee68a53c 100644 --- 
a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -667,12 +667,11 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, template bool ShapeTree::operator==(const ShapeTree& other) const { bool equal = true; - ForEachElement( - [this, &other, &equal](const ShapeIndex& index, const T& data) { - if (data != other.element(index)) { - equal = false; - } - }); + ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) { + if (data != other.element(index)) { + equal = false; + } + }); return equal; } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c8ff55e7845785d9292516b823fb591cc28cbfad..2b6c484bc4f205be0180403eeac2dd391029b110 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -52,10 +52,10 @@ class ShapeTreeTest : public ::testing::Test { TEST_F(ShapeTreeTest, DefaultConstructor) { ShapeTree int_tree; - EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(int_tree.shape())); ShapeTree bool_tree; - EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(bool_tree.shape())); } void ShapeTreeTest::TestShapeConstructor(const Shape& shape, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d0c35d8dee46a1e0a5e343e0506a14ca1ce38bfd..a4d4e1e53e727bdf7822cacaa4559fcae59d4eae 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -79,14 +79,14 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { indices_.subspan(0, prefix.size()) == prefix.indices_; } -namespace { - -// Returns whether the given primitive type corresponds to an array shape. 
-bool IsArrayPrimitiveType(PrimitiveType primitive_type) { +/* static */ bool ShapeUtil::IsArrayPrimitiveType( + PrimitiveType primitive_type) { return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && primitive_type != OPAQUE && primitive_type != TOKEN; } +namespace { + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. @@ -121,6 +121,23 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } + + const auto& lhs_tiles = lhs.layout().tiles(); + const auto& rhs_tiles = rhs.layout().tiles(); + if (lhs_tiles.size() != rhs_tiles.size()) { + return false; + } + for (int64 i = 0; i < lhs_tiles.size(); i++) { + if (!absl::c_equal(lhs_tiles[i].dimensions(), + rhs_tiles[i].dimensions())) { + return false; + } + } + + if (lhs.layout().element_size_in_bits() != + rhs.layout().element_size_in_bits()) { + return false; + } } } @@ -203,7 +220,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ ProgramShape ShapeUtil::MakeProgramShape( std::initializer_list parameters, Shape result) { ProgramShape program_shape; - for (const auto& shape : parameters) { + for (const Shape& shape : parameters) { *program_shape.add_parameters() = shape; } *program_shape.mutable_result() = std::move(result); @@ -272,7 +289,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); - result.mutable_tuple_shapes()->Reserve(shapes.size()); + result.mutable_tuple_shapes()->reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -372,10 +389,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsTuple(shape) && TupleElementCount(shape) == 0; } -/* static */ bool ShapeUtil::IsNil(const Shape& 
shape) { - return IsEmptyTuple(shape); -} - /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { CHECK(IsTuple(shape)) << HumanString(shape); return shape.tuple_shapes_size(); @@ -571,7 +584,7 @@ namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = StripLeadingAsciiWhitespace(*s); + *s = absl::StripLeadingAsciiWhitespace(*s); if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; @@ -584,7 +597,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = StripLeadingAsciiWhitespace(*s); + *s = absl::StripLeadingAsciiWhitespace(*s); must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); @@ -1155,7 +1168,7 @@ Status ForEachMutableSubshapeHelper( // Let the argument `permutation` be P. This is a permutation over `shape`'s // dimensions, so our return value will be a shape with dims P.I = P. Our // goal is to construct a layout permutation L* that we can apply to P such - // that that the physical dimension ordering of the returned shape is the same + // that the physical dimension ordering of the returned shape is the same // as that of the original shape, namely L'. 
// // Our returned shape has dims P and layout L*, so its in-memory layout is @@ -1600,7 +1613,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { CHECK(IsArray(shape)); - shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); + shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() + + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); @@ -1634,11 +1648,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } -std::ostream& operator<<(std::ostream& out, const Shape& shape) { - out << ShapeUtil::HumanStringWithLayout(shape); - return out; -} - /*static*/ size_t ShapeUtil::Hash(const Shape& shape) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index a7a3026cf3f3a53d34d389212738ca584a19db1d..84a27f662a57ba274562e2e9be57b7e971c9b477 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -37,6 +38,7 @@ limitations under the License. 
#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -100,6 +102,11 @@ class ShapeIndex { string ToString() const; + template + friend H AbslHashValue(H h, const ShapeIndex& index) { + return H::combine(std::move(h), index.indices_); + } + private: container_type indices_; }; @@ -461,6 +468,9 @@ class ShapeUtil { // arrays. static bool IsArray(const Shape& shape); + // Returns whether the given primitive type corresponds to an array shape. + static bool IsArrayPrimitiveType(PrimitiveType primitive_type); + // Returns whether the shape is a tuple with at least one element which is // also a tuple. static bool IsNestedTuple(const Shape& shape); @@ -468,9 +478,6 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is the nil shape (an empty tuple). - static bool IsNil(const Shape& shape); - // Returns the number of elements in the given tuple shape. // Precondition: IsTuple(shape) static int64 TupleElementCount(const Shape& shape); @@ -754,10 +761,18 @@ class ShapeUtil { pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); } + tensorflow::mutex mu; + Status status; // Guarded by mu + while (n < rank) { if (pool != absl::nullopt) { - pool->Schedule( - [indexes, &visitor_function] { visitor_function(indexes); }); + pool->Schedule([indexes, &visitor_function, &mu, &status] { + StatusOr result = visitor_function(indexes); + if (!result.ok()) { + tensorflow::mutex_lock lock(mu); + status = status.ok() ? result.status() : status; + } + }); } else { TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); if (!should_continue) { @@ -775,14 +790,14 @@ class ShapeUtil { } } - return Status::OK(); + // Waits for the scheduled work to complete. 
+ pool.reset(); + return status; } TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; -std::ostream& operator<<(std::ostream& out, const Shape& shape); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0c647369a37e70f93abe1732963d2ddc7730c214..60bdbe302045e6f3b4bae500c50bc68fb217525d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -376,12 +376,12 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { } TEST(ShapeUtilTest, NilShape) { - EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); } @@ -546,68 +546,6 @@ TEST(ShapeUtilTest, IsLeafIndex) { EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1})); } -TEST(ShapeUtilTest, HumanString) { - Shape opaque = ShapeUtil::MakeOpaqueShape(); - Shape token = ShapeUtil::MakeTokenShape(); - Shape scalar = ShapeUtil::MakeShape(F32, {}); - Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); - Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); - Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); - - EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); - EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); - 
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); - EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); - EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); - EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", - ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(nested_tuple)); - - EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); - EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar)); - EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix)); - EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); - EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", - ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ( - "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " - "token[])", - ShapeUtil::HumanStringWithLayout(nested_tuple)); - - ProgramShape prog = ShapeUtil::MakeProgramShape( - {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); - EXPECT_EQ( - "((unknown): opaque[], " - "(unknown): f32[], " - "(unknown): u32[1,2], " - "(unknown): s32[3,4], " - "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); - - prog.add_parameter_names("arg0"); - prog.add_parameter_names("scalar"); - prog.add_parameter_names("matrix"); - prog.add_parameter_names("matrix2"); - prog.add_parameter_names("tuple"); - prog.add_parameter_names("nested_tuple"); - EXPECT_EQ( - "(arg0: opaque[], " - "scalar: f32[], " - "matrix: u32[1,2], " - "matrix2: s32[3,4], " - "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " - "token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); -} - TEST(ShapeUtilTest, 
ForEachSubshapeArray) { const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); int calls = 0; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index db34d34f969311543d988ec6c3b8ee2af5b07e8e..5a7a4faa7e89b27fb537f20d94c21cb4a76e000d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -79,6 +79,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -135,6 +136,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -297,6 +299,52 @@ xla_test( ], ) +xla_test( + name = "conv_depthwise_test", + timeout = "long", + srcs = ["conv_depthwise_test.cc"], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + +xla_test( + name = "grouped_convolution_test", + timeout = "long", + srcs = ["grouped_convolution_test.cc"], + blacklisted_backends = [ + # disabled because of a break b/119590850. + "gpu", + # disabled because it times out. 
+ "cpu", + ], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], @@ -1265,6 +1313,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1865,6 +1914,7 @@ xla_test( xla_test( name = "multioutput_fusion_test", srcs = ["multioutput_fusion_test.cc"], + backends = ["gpu"], deps = [ "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 2180b22cb3bc2e1cdd484098bafd14315d1fa142..915b456b52215f8d6a9eb6c5b933f3502f1d3d2c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -329,13 +329,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); - auto b_param = ConstantR1(&builder, b_values); + auto b_param = Parameter(&builder, 1, a_literal.shape(), "b_param"); + auto 
b_constant = ConstantR1(&builder, b_values); - auto sum1 = Add(a_constant, b_constant); - auto sum2 = Add(a_constant, b_param); - auto sum3 = Add(a_param, b_constant); - auto sum4 = Add(a_param, b_param); + auto sum1 = Add(a_constant, b_param); + auto sum2 = Add(a_constant, b_constant); + auto sum3 = Add(a_param, b_param); + auto sum4 = Add(a_param, b_constant); auto sum = Add(sum1, sum2); sum = Add(sum, sum3); @@ -350,6 +350,44 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } +// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. +XLA_TEST_F(ArrayElementwiseOpTest, + DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { + XlaBuilder builder(TestName()); + std::vector values(30, 0.0); + auto a_literal = LiteralUtil::CreateR1(values); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b_literal = LiteralUtil::CreateR1(values); + auto b = Parameter(&builder, 1, b_literal.shape(), "x"); + + // Construct a sequence of diamond-shaped gadgets like this: + // + // add + // / \ + // slice slice + // \ / + // add + // + // Each 'left' slice removes the last element, each 'right' slice removes the + // first element. In this way, we index into the add with different + // multi-dimensional index arrays, which defeats the caching we use to avoid + // exponential compile time. 
+ std::function generate_recursive = + [&](int64 slice_size) -> XlaOp { + if (slice_size == values.size()) { + return Add(a, b); + } + XlaOp param = generate_recursive(slice_size + 1); + auto slice1 = Slice(param, {0}, {slice_size}, {1}); + auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); + return Add(slice1, slice2); + }; + generate_recursive(1); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, {0.0}, {a_data.get(), b_data.get()}); +} + XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); @@ -2744,12 +2782,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); const string expected = R"(pred[2,3,2] { -{ { 0, 1 }, +{ + { 0, 1 }, { 0, 0 }, - { 0, 0 } }, -{ { 0, 1 }, + { 0, 0 } +}, +{ + { 0, 1 }, { 1, 0 }, - { 0, 1 } } + { 0, 1 } +} })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index dde19fb65d65064c9452a6ac49c70e20cf113336..702fb32adfc8a0ded26845c92245776a79777c34 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -161,8 +161,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {2, 2}), {1}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {1}); Array2D expected(2, 2); expected(0, 0) = 1; @@ -175,8 +174,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XlaBuilder b(TestName()); - 
BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {2, 2}), {0}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {0}); Array2D expected(2, 2); expected(0, 0) = 1; @@ -189,8 +187,8 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), - ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, + {0, 1}); Array3D expected(2, 2, 2); expected(0, 0, 0) = 1.0; @@ -207,8 +205,8 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), - ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, + {0, 2}); Array3D expected(2, 2, 2); expected(0, 0, 0) = 1.0; @@ -225,8 +223,7 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {3, 2}), {1}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {3, 2}, {1}); Array2D expected(3, 2); expected(0, 0) = 1; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index b98572e24c831c1ff746904302cacccb20056207..12c029983336cc9aed0fde4ce6881c9a00a9869e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -107,7 +107,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransfer( ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; + 
shape_with_output_layout->ToProto(); } return client_->ExecuteAndTransfer(computation, arguments, &execution_options); @@ -127,7 +127,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; + shape_with_output_layout->ToProto(); } execution_options.clear_device_handles(); return ref_client_->ExecuteAndTransfer(computation, arguments, diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 34148e5886d3806b19fc5bee90806c5678df345e..65a23dd883594b9bf9c37494a37e9be39b197788 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,7 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); - opts->set_xla_gpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_min_max(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 6f2ca84bb646e88af221ab80b727911ff7d990eb..363dee74b2755a6bdc3c5a5164a85378581c21d2 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -50,7 +50,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, - execute_layout); + execute_layout) + .ToProto(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); @@ -84,7 +85,8 @@ 
XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { {ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, - /*minor_to_major=*/{1, 0})}); + /*minor_to_major=*/{1, 0})}) + .ToProto(); TF_ASSERT_OK_AND_ASSIGN( auto result, diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9811a015e91d866d6f4de6ebb6dac536ed6c7e06..4f5b525a34252db9e967a55af0d1bf39a2dd830e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -492,6 +492,32 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } +XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { + XlaBuilder builder(TestName()); + auto a_literal = LiteralUtil::CreateR1({256.0}); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b = ConcatInDim(&builder, {a, a}, 0); + auto c = ConcatInDim(&builder, {b, b}, 0); + auto d = ConcatInDim(&builder, {c, c}, 0); + auto e = ConcatInDim(&builder, {d, d}, 0); + auto f = ConcatInDim(&builder, {e, e}, 0); + auto g = ConcatInDim(&builder, {f, f}, 0); + auto h = ConcatInDim(&builder, {g, g}, 0); + auto i = ConcatInDim(&builder, {h, h}, 0); + auto j = ConcatInDim(&builder, {i, i}, 0); + auto k = ConcatInDim(&builder, {j, j}, 0); + auto l = ConcatInDim(&builder, {k, k}, 0); + auto m = ConcatInDim(&builder, {l, l}, 0); + auto n = ConcatInDim(&builder, {m, m}, 0); + auto o = ConcatInDim(&builder, {n, n}, 0); + auto p = ConcatInDim(&builder, {o, o}, 0); + auto q = ConcatInDim(&builder, {p, p}, 0); + ConcatInDim(&builder, {q, q}, 0); + std::vector expected(131072, 256.0); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, expected, {a_data.get()}); +} + // Describes a binary rank-2 concatenation test. 
struct R2BinarySpec { int64 lhs_dim0; diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..627a17a0ca114085240dbaf28211bb3511cf0cab --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? 
"bf16" : "f32"; +} + +struct DepthwiseConvolution2DSpec { + int64 output_feature, window, stride, pad, lhs_dilate; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class DepthwiseConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172}, + {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}}; + + for (auto option : config_options) { + int64 feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + + std::vector kernel_layout = {3, 2, 1, 0}; + DepthwiseConvolution2DSpec config; + config.output_feature = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, 1, feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, feature}; + } else if (feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = feature / 32; + config.output_dims = {batch, feature / 32, + activation_size - kernel_size + 1, feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, feature}; + } + + // Try this layout for all kernel shapes. 
+ config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. + if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. 
+ absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.output_feature); + + } else if (spec.stride == -1) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + 
absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.output_feature); + } else { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature); + } +} + +XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { + const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + DepthwiseConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc 
b/tensorflow/compiler/xla/tests/convolution_test.cc index 211d004ec8c0a04b17c2454995880c0b565d3d4d..4a58a1ed66c438d1dd9561f4eb029b38d8c6cbdd 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -721,23 +721,573 @@ class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest { ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes + : public 
ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + 
client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + 
auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + 
ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + 
std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class 
Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto 
input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 1024}; + std::vector filter_dims = {3, 3, 1, 1024}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); - auto filter_r = filter_r1.Reshape(filter_dims); + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1024); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(4096, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); } }; -TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) { this->RunTest(); } template -class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { +class 
Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); - std::vector input_dims = {1, 4, 4, 160}; - std::vector filter_dims = {3, 3, 1, 160}; + std::vector input_dims = {1, 2, 2, 6}; + std::vector filter_dims = {2, 2, 2, 12}; Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { @@ -760,23 +1310,89 @@ class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { dnums.set_kernel_output_feature_dimension(3); ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, - /*feature_group_count=*/160); + /*feature_group_count=*/3); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(5076), static_cast(5160), static_cast(5244), + static_cast(5328), static_cast(6164), static_cast(6264), + static_cast(6364), static_cast(6464), static_cast(7380), + static_cast(7496), static_cast(7612), static_cast(7728)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) { + this->RunTest(); +} + +template 
+class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 1024}; + std::vector filter_dims = {2, 2, 128, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/8); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); - std::vector output_elems(640, static_cast(18)); - + std::vector output_elems(512, static_cast(1024)); auto expected_r1 = LiteralUtil::CreateR1(output_elems); - auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(input_r4).ConsumeValueOrDie(); @@ 
-786,24 +1402,21 @@ class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); - - auto filter_r = filter_r1.Reshape(filter_dims); } }; -TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) { this->RunTest(); } template -class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid - : public ConvolutionTest { +class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); - std::vector input_dims = {1, 4, 4, 1024}; - std::vector filter_dims = {3, 3, 1, 1024}; + std::vector input_dims = {1, 2, 2, 1024}; + std::vector filter_dims = {2, 2, 128, 8}; Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { @@ -826,23 +1439,24 @@ class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid dnums.set_kernel_output_feature_dimension(3); ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, - /*feature_group_count=*/1024); + /*feature_group_count=*/8); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); - std::vector output_elems(4096, static_cast(18)); - + std::vector output_elems(8, static_cast(1024)); auto expected_r1 = LiteralUtil::CreateR1(output_elems); - auto expected_r4 = expected_r1.Reshape({1, 2, 2, 
1024}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(input_r4).ConsumeValueOrDie(); @@ -852,23 +1466,21 @@ class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); - - auto filter_r = filter_r1.Reshape(filter_dims); } }; -TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) { this->RunTest(); } template -class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { +class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); - std::vector input_dims = {1, 2, 2, 6}; - std::vector filter_dims = {2, 2, 2, 12}; + std::vector input_dims = {1, 2, 2, 12}; + std::vector filter_dims = {2, 2, 3, 4}; Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { @@ -891,7 +1503,7 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { dnums.set_kernel_output_feature_dimension(3); ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, - /*feature_group_count=*/3); + /*feature_group_count=*/4); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); @@ -904,12 +1516,140 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + auto expected_r1 = + LiteralUtil::CreateR1({static_cast(7712), static_cast(8816), + static_cast(9992), static_cast(11240)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 
4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 12}; + std::vector filter_dims = {2, 2, 4, 3}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(3); + dnums.set_kernel_output_feature_dimension(2); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4_relaid = + filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); auto expected_r1 = LiteralUtil::CreateR1( - {static_cast(5076), static_cast(5160), static_cast(5244), - static_cast(5328), static_cast(6164), static_cast(6264), - static_cast(6364), static_cast(6464), static_cast(7380), - static_cast(7496), static_cast(7612), static_cast(7728)}); - auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + {static_cast(6968), static_cast(8516), static_cast(10280), + static_cast(12260)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4_relaid).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + 
error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes, + TestTypes); +TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 1, 1, 12}; + std::vector filter_dims = {1, 1, 3, 4}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = + LiteralUtil::CreateR1({static_cast(38), static_cast(98), + static_cast(176), 
static_cast(272)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(input_r4).ConsumeValueOrDie(); @@ -922,8 +1662,8 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { } }; -TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) { this->RunTest(); } @@ -1217,6 +1957,18 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF32ForwardReversed)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f32[3,56,56,16] parameter(0) + %arg1 = f32[3,3,3,32] parameter(1) + ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6c0847a875798870b4362a99ac2ab65d99f9f3e6..c5d8b663f4abe77e05ec213d2e4e075c260a8655 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace xla { namespace { @@ -637,6 +636,76 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } +#ifndef XLA_TEST_BACKEND_CPU +// TODO(b/74459949): failed on CPU on 2018-10-29. +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); + + auto y_data = this->client_ + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( + {{1.0f, 0.0f}, {0.0f, 1.0f}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR2( + &builder, + /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, + this->error_spec_); +} + +// TODO(b/74459949): failed on CPU on 2018-10-29. 
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2}), "x"); + auto y = + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(x, y, dnums); + + auto x_data = this->client_ + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( + {{1.0f, 0.0f}, {0.0f, 1.0f}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR2( + &builder, + /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, + this->error_spec_); +} +#endif // XLA_TEST_BACKEND_CPU + XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { using T = TypeParam; diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f7049910e70c4e591636a47c1b6ba72cf2c234f --- /dev/null +++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct GroupedConvolution2DSpec { + int64 input_feature, output_feature, window, stride, pad, lhs_dilate; + int64 group_size, group_count; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class GroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + // Add to this set if you want a new test configuration. + // Rule : the penultimate number must be divisible by the last number. 
+ std::vector> config_options = {{8, 2, 2, 1, 1024, 128}, + {512, 3, 3, 144, 1024, 16}, + {256, 3, 3, 129, 512, 64}, + {64, 1, 2, 127, 32, 8}, + {256, 3, 3, 256, 1024, 4}}; + + for (auto option : config_options) { + int64 output_feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + int64 input_feature = option[4]; + int64 group_size = option[5]; + + std::vector kernel_layout = {3, 2, 1, 0}; + GroupedConvolution2DSpec config; + config.group_size = group_size; + config.group_count = input_feature / group_size; + config.output_feature = output_feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, + input_feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, group_size, output_feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, output_feature}; + } else if (output_feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = output_feature / 32; + config.output_dims = {batch, output_feature / 32, + activation_size - kernel_size + 1, output_feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, output_feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. 
+ if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string GroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextGroupedConvolution2D(const GroupedConvolution2DSpec& spec, + bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + // Check for outer dim. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.group_count); + + } else if (spec.stride == -1) { + // Check for basic, non-dilated cases. + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.group_count); + } else { + // Check for base dilations. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.group_count); + } +} + +XLA_TEST_P(GroupedConvolution2DTest, DoIt) { + const GroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = BuildHloTextGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + GroupedConvolution2DTestWithRandomIndices, GroupedConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + GroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d8fa00272f8f19ab843fd32a66fd6d6842997bdb..989a7c705a8254f99e5cc0e97dfde5942f146964 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ 
b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -99,6 +99,8 @@ void VerifiedHloModule::VerifyOrAddFailure(const string& message) { ADD_FAILURE() << "HloVerifier failed on module " << name() << (message.empty() ? "" : absl::StrCat(" (", message, ")")) << ": " << status; + LOG(ERROR) << "Contents of bad module:"; + XLA_LOG_LINES(tensorflow::ERROR, ToString()); } } @@ -140,14 +142,6 @@ std::unique_ptr HloTestBase::CreateNewVerifiedModule( allow_mixed_precision_in_hlo_verifier_); } -StatusOr> -HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - auto module = absl::make_unique(TestName(), config); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - return std::move(module); -} - StatusOr> HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, const HloModuleConfig& config) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 366726d90b4752b6d53dc2133c8b0b5bbafce086..1d1e7f437296a7493ef7da07039fcf6d273f35bc 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -100,6 +101,7 @@ class HloTestBase : public ::testing::Test { // // This returns a vanilla HloModule that doesn't run the HLO verifier on // destruction. + ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") std::unique_ptr CreateNewUnverifiedModule( const string& name = TestName()); @@ -108,12 +110,6 @@ class HloTestBase : public ::testing::Test { std::unique_ptr CreateNewVerifiedModule( const string& name = TestName()); - // Parses the given string and returns module as a vanilla, unverified - // HloModule. 
- StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - // Parses the given string and returns module as a VerifiedHloModule. StatusOr> ParseAndReturnVerifiedModule( absl::string_view hlo_text, diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 310f3495922250d68aa463fcbb24ef0b04603d09..65205f53ddc582ae477d67705f161fef1e31b857 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -113,5 +113,26 @@ INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, /*step=*/10), ::testing::Values(0, 1, 2))); +class IotaR3PredTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(IotaR3PredTest, DoIt) { + const auto element_type = PRED; + const int64 num_elements = 2; + const int64 iota_dim = GetParam(); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3PredTestInstantiation, IotaR3PredTest, + ::testing::Values(0, 1, 2)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 5cf87e565bf493167f5173588e7afa3b96282488..34c7dc7c46427b2d18ea21fc286ee03175f70800 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - 
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = @@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. std::unique_ptr x_data = @@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index dedc95b5ae8315185a35f786af42aad53bd7ad96..298136002e9ef47188e0bae95af3f596596e6062 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -618,7 +618,8 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? 
BF16 : F32, {2, 8}, - {1, 0}); + {1, 0}) + .ToProto(); Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) @@ -767,7 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, - {2, 3, 0, 1}); + {2, 3, 0, 1}) + .ToProto(); Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 7e1f4aa0eb4801876d9bdbac6a4d7f1d09f81ba8..32de0fdf78f9c442e17c55e1b951e39122dac5ef 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -129,6 +129,42 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_InversePermutation) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + permutation = s32[3,4] parameter(0) + reshape = s32[3,4,1] reshape(permutation) + operand = s32[3,4] iota(), iota_dimension=1 + updates = s32[3,4,1,1] iota(), iota_dimension=1 + iota = s32[3,4,1] iota(), iota_dimension=0 + indices = s32[3,4,2] concatenate(iota, reshape), dimensions={2} + ROOT scatter = s32[3,4] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={2,3}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=2 +} +)"; + Literal permutation = + LiteralUtil::CreateR2({{1, 3, 2, 0}, {3, 0, 2, 1}, {2, 3, 1, 0}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + auto actual = ExecuteAndTransfer(std::move(module), 
{&permutation}); + Literal expected = + LiteralUtil::CreateR2({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); +} + XLA_TEST_F(ScatterTest, SimpleR4) { const char* hlo_text = R"( HloModule SimpleR4 diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2f18036ff4c5b0bfa28723fb181c33fa6995eb80..eafa48ed7b8cf2bd67fe767ad36082661dbbd66e 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -28,65 +29,113 @@ namespace xla { namespace { template -void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); + } +} + +template +void PopulateWithIntNext(Literal* literal); + +template <> +void PopulateWithIntNext(Literal* literal) { + // Duplicates may be generated if we don't have enough bits. + uint16 next_value = 0; + for (half& value : literal->data()) { + // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into + // the sign bit. We could be less wasteful, but this is best-effort anyway. + uint16 exponent_msb = next_value & 0x4000; + value.x = (next_value & 0xBFFF) | (exponent_msb << 1); + next_value++; + } +} + +template <> +void PopulateWithIntNext(Literal* literal) { + // Duplicates may be generated if we don't have enough bits. + // Start at 0x80 rather than 0 to avoid denormals. 
+ uint16 next_value = 0x80; + for (bfloat16& value : literal->data()) { + // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into + // the sign bit. We could be less wasteful, but this is best-effort anyway. + uint16 exponent_msb = next_value & 0x4000; + value.value = (next_value & 0xBFFF) | (exponent_msb << 1); + next_value++; + } +} + +template +void PopulateWithNextAfter(Literal* literal) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. + float next_value = std::numeric_limits::min(); + for (float& value : literal->data()) { + value = next_value; + next_value = std::nextafter(next_value, std::numeric_limits::max()); + } +} + +template ::value || + std::is_same::value, + int>::type = 0> +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithIntNext(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +template ::value && + !std::is_same::value, + int>::type = 0> +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithNextAfter(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +template +void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); if (no_duplicates) { - // Duplicates may be generated if the number of elements in the literal - // exceeds the number of positive values supported by the type. 
- FloatT next_value = std::numeric_limits::min(); - for (FloatT& value : literal->data()) { - value = next_value; - next_value = - std::nextafter(next_value, std::numeric_limits::max()); - } - std::shuffle(literal->data().begin(), literal->data().end(), - *engine); + PopulateWithNoDuplicateData(literal, engine); } else { - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (FloatT& value : literal->data()) { - value = static_cast(generator(*engine)); - } + PopulateWithRandomFloatingPointData(literal, engine); } } -template -void PopulateWithRandomFloatingPointData(Literal* literal, +template <> +void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, bool no_duplicates) { CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine, - no_duplicates); -} - -template <> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - // no_duplicates is ignored for half types. Unique values can only be - // generated for arrays with fewer than ~2**16 elements and no_duplicates is - // best-effort anyway. - CHECK(engine != nullptr); - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (half& value : literal->data()) { - value = static_cast(generator(*engine)); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); } } template <> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - // no_duplicates is ignored for bfloat types. Unique values can only be - // generated for arrays with fewer than ~2**16 elements and no_duplicates is - // best-effort anyway. 
+void PopulateWithFloatingPointData(Literal* literal, + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (bfloat16& value : literal->data()) { - value = static_cast(generator(*engine)); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); } } @@ -135,20 +184,16 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case S8: PopulateWithRandomIntegralData(&literal, engine, no_duplicates); diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e066b3f4f224e80dab1b69c12fe76855d2967401..e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -175,5 +175,28 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( } } +XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort, is_scheduled=true + +ENTRY %sort. 
(parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { + %parameter.0 = bf16[2,1452]{1,0} parameter(0) + %parameter.1 = s32[2,1452]{1,0} parameter(1) + ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = args[0]; + + absl::flat_hash_set key_set; + for (const bfloat16& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(absl::bit_cast(value)).second); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index a2b7c26331b3cc89ed0413efe8eb31c2b9e37038..601c6b06938fef1f1ae809b33209ae59b24c70a2 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -108,26 +109,6 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); } -XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { - std::unique_ptr module = CreateNewUnverifiedModule(); - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction(HloInstruction::CreateAfterAll({param})); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - module->AddEntryComputation(builder.Build()); - - Status status = - HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) - .Run(module.get()) - .status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr( - "Operands of token instructions must be TOKEN types")); -} - XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { // Thread a token around a while loop. Token is created and consumed by a // AfterAll instruction in the while body. @@ -220,5 +201,95 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { } } +XLA_TEST_F(TokenHloTest, AddDependency) { + string module_string = R"( +HloModule AddDependency, is_scheduled=true + +// Computes (p0 + 42) * (-p1) +// where there is a dependency from the add to the negation using a token +// with after-all and add-dependency instructions. 
+ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + + %forty_two = f32[] constant(42.0) + %add = f32[] add(f32[] %p0, f32[] %forty_two) + %token = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %neg = f32[] negate(f32[] %p1_after_token) + ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto p1 = LiteralUtil::CreateR0(3.0); + auto expected = LiteralUtil::CreateR0(-156.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + +XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) { + string module_string = R"( +HloModule AddDependencyOfConstant, is_scheduled=true + +ENTRY %AddDependency (p0: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %forty_two = f32[] constant(42.0) + %token = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto expected = LiteralUtil::CreateR0(420.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0})); +} + +XLA_TEST_F(TokenHloTest, AddDependencyAsRoot) { + string module_string = R"( +HloModule AddDependencyAsRoot, is_scheduled=true +ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p = f32[3] parameter(0) + %neg = f32[3] negate(f32[3] %p) + %token = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto input = LiteralUtil::CreateR1({1.0, 3.0, 7.0}); + auto expected = 
LiteralUtil::CreateR1({-1.0, -3.0, -7.0}); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&input})); +} + +XLA_TEST_F(TokenHloTest, TupleShapedAddDependency) { + string module_string = R"( +HloModule TupleShapedAddDependency, is_scheduled=true +ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { + %p0 = f32[3] parameter(0) + %p1 = f32[3] parameter(1) + %forty_two = f32[] constant(42.0) + %token = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 + %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 + ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR1({3.0, 3.0, 47.0}); + auto p1 = LiteralUtil::CreateR1({1.0, -2.0, 2.0}); + auto expected = LiteralUtil::CreateR1({2.0, 5.0, 45.0}); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index ca036f1ae0d5e31a3f83d9d31c80e070c2a666df..e57d072a0632b492b8b6e34439f4e80332b843b6 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -157,10 +157,12 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + ExecutableBuildOptions build_options; + build_options.mutable_debug_options()->set_xla_hlo_profile(true); 
TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions().set_hlo_profile(true))); + build_options)); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( @@ -208,7 +210,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape, rhs_shape); - + VLOG(4) << "Profile Output:\n" << profile_output; std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 47be9f5adf1063463d7678579a7f394684aaf357..ff2c3399928c0e6339304323c4f93e212933a340 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -82,13 +82,17 @@ struct Options { std::unique_ptr CompileExecutable(const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); - std::vector argument_layouts; - for (const auto& param : + std::vector argument_layouts; + argument_layouts.reserve( + computation.proto().host_program_shape().parameters_size()); + std::vector argument_layout_ptrs; + for (const ShapeProto& param : computation.proto().host_program_shape().parameters()) { - argument_layouts.push_back(¶m); + argument_layouts.push_back(Shape(param)); + argument_layout_ptrs.push_back(&argument_layouts.back()); } return client - ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions()) .ValueOrDie(); } @@ -149,7 +153,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "--generate_fake_infeed only works if the model has 0 or 1 " "infeed ops, but this one has >= 2."; provide_infeed = true; - infeed_shape = instruction.shape(); + infeed_shape = 
Shape(instruction.shape()); LOG(INFO) << "Generating fake infeed shape for inferred shape: " << ShapeUtil::HumanString(infeed_shape); } @@ -315,9 +319,10 @@ int RealMain(absl::Span args, const Options& opts) { if (snapshot.has_result()) { Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal.ToString().c_str()); + fprintf( + stdout, "was %s:%s\n", + ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(), + literal.ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 8ce741647414a1fa75e6d706ec1e719ace7b7cc8..6722641e9d2c177440361e6f0d1f6c0804eb7cda 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -152,6 +152,13 @@ static inline absl::Span AsInt64Slice( slice.size()); } +// TODO(b/29771030): This nop overload was added to simplify the migration of +// Shape from a proto to a C++ class. Remove after class has been migrated. +static inline absl::Span AsInt64Slice( + absl::Span slice) { + return slice; +} + // As above, but for uint64 types. static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { @@ -387,6 +394,19 @@ T CeilOfRatio(T dividend, T divisor) { return tensorflow::MathUtil::CeilOfRatio(dividend, divisor); } +template +std::vector ElementWiseCeilOfRatio(absl::Span dividends, + absl::Span divisors) { + std::vector ceil_of_ratios; + CHECK_EQ(dividends.size(), divisors.size()); + ceil_of_ratios.reserve(dividends.size()); + absl::c_transform(dividends, divisors, std::back_inserter(ceil_of_ratios), + [](const T dividend, const T divisor) { + return CeilOfRatio(dividend, divisor); + }); + return ceil_of_ratios; +} + // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio // then multiplying by the divisor. 
For example: RoundUpToNearest(13, 8) => 16 template diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 8ea8dbab2574ca1e24271e7c1c7762d4a6b6a8de..51c73b3d17e4c32d9a8a14d3055ab56f02922af3 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -185,6 +185,17 @@ bool HasWindowReversal(const Window& window) { return false; } +bool AllOrNoneReversed(const Window& window) { + if (window.dimensions().empty()) { + return true; + } + bool reversed = window.dimensions()[0].window_reversal(); + return std::all_of(window.dimensions().begin(), window.dimensions().end(), + [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 1fb9e855fc16f334eb0e83dfd27b307b2149628f..099d7ecdd5c732ffc8c6ff6370288a2fc4144fa2 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -56,6 +56,7 @@ bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); bool HasWindowReversal(const Window& window); +bool AllOrNoneReversed(const Window& window); // Returns true if the given logical dimension is inactive in the sense that it // has window bound 1, no striding and no padding. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 28df3b03f398841460189910bc3a5096dfb0d367..a37eac7fe441d91aa71e1b6fd7b84099fee2215b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -193,7 +193,11 @@ message DebugOptions { // - Assuming that operations never produce or consume NaN or +/- Inf. // - Assuming that +0 and -0 are indistinguishable. 
bool xla_cpu_enable_fast_math = 99; - bool xla_gpu_enable_fast_math = 100; + + // When true we lower the Minimum and Maximum hlos in the GPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag + // this is true we don't propagate NaNs through Min and Max. + bool xla_gpu_enable_fast_min_max = 100; // Crashes the program when any kind of verification fails, instead of just // logging the failures. One example is cross checking of convolution results @@ -209,6 +213,9 @@ message DebugOptions { // the host that run models in parallel across multiple devices. int32 xla_force_host_platform_device_count = 102; + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_ptxas_optimizations = 103; + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -224,7 +231,7 @@ message ExecutionOptions { // may be faster when using this layout. // // We use a Shape here to accommodate computations that return a tuple. - Shape shape_with_output_layout = 2; + ShapeProto shape_with_output_layout = 2; // Used to seed random-number generators used in this computation. If this is // 0, we generate a seed ourselves. @@ -253,7 +260,7 @@ message TransferToClientRequest { // This optional field directs the service to return the literal in this // layout. A shape is used to hold the layout to accommodate tuples. - Shape shape_with_layout = 2; + ShapeProto shape_with_layout = 2; } message TransferToClientResponse { @@ -281,7 +288,7 @@ message TransferToInfeedResponse { message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this // layout. A shape is used to hold the layout to accommodate tuples. 
- Shape shape_with_layout = 1; + ShapeProto shape_with_layout = 1; int64 replica_id = 2; DeviceHandle device_handle = 3; @@ -332,7 +339,7 @@ message CompileRequest { // The layouts of the input arguments. If not set, the default layout will be // used. Although the real arguments are not needed in compilation, the // layouts of the arguments can affect the compilation. - repeated Shape input_shape_with_layout = 3; + repeated ShapeProto input_shape_with_layout = 3; } message CompileResponse { @@ -406,7 +413,7 @@ message LoadDataRequest { string columnio_field = 2; // Individual element shape, excluding rows. - Shape element_shape = 3; + ShapeProto element_shape = 3; // Warning: ColumnIO does not support random-access, so use offset with // caution in performance-critical scenarios. @@ -422,7 +429,7 @@ message LoadDataRequest { message LoadDataResponse { GlobalDataHandle data = 1; - Shape data_shape = 2; + ShapeProto data_shape = 2; int64 available_rows = 3; int64 rows_loaded = 4; int64 nanoseconds = 5; @@ -433,7 +440,7 @@ message GetShapeRequest { } message GetShapeResponse { - Shape shape = 1; + ShapeProto shape = 1; } message UnpackRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 683ccc40f162ead3a248aee83d9abf3086a1ac93..85ec83437a10d973687a7fb84285c2e2541a53c7 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -108,6 +108,16 @@ enum Format { SPARSE = 2; } +// Describes a tile used in tiling-based layout. Refer to +// g3doc/layout_with_tiling.md for details about tiling-based layout. +message Tile { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + // A layout describes how the array is placed in (1D) memory space. 
This // includes the minor-to-major ordering of dimensions within a shape. // @@ -138,6 +148,20 @@ message Layout { // memory. This field must be unset unless the format is SPARSE. int64 max_sparse_elements = 5; + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated Tile tiles = 6; + + // Bit size of each element. If the size is bigger than what the element + // type requires, the value is stored in the least significant + // bits and the additional most significant bits are filled with 0's. + // + // TODO(b/119839262): implement in each backend or add Unimplemented error. + int64 element_size_in_bits = 7; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and // LayoutUtil::Hash appropriately to account for the new field. } @@ -154,7 +178,7 @@ message Layout { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Shape { +message ShapeProto { reserved 1; reserved "rank"; @@ -169,7 +193,7 @@ message Shape { repeated int64 dimensions = 3; // For tuples only, the shapes of constitutent shapes in the tuple sequence. - repeated Shape tuple_shapes = 4; + repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. Layout layout = 5; @@ -183,9 +207,9 @@ message Shape { // Shape of the parameters and output of a computation (like a traditional // function signature). -message ProgramShape { - repeated Shape parameters = 1; - Shape result = 2; +message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; repeated string parameter_names = 3; } @@ -320,7 +344,7 @@ message DeviceAssignmentProto { // Transfers to/from the client are encoded in literal form, and the structure // of the repeated fields is implied by the shape. 
message LiteralProto { - Shape shape = 1; + ShapeProto shape = 1; repeated bool preds = 2; bytes s8s = 15; bytes u8s = 3; @@ -521,7 +545,7 @@ message OpSharding { } Type type = 1; // The shape of the sharded tile. - Shape tile_shape = 2; + ShapeProto tile_shape = 2; // The shape of the tile assignment tensor - this must be the same rank as // tile_shape and the product of its dimensions must equal // tile_assignment_devices.size(). diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 2ff97914f862e0ec30fc54602ec5fee2a0a5ebca..2dae746d034a1bf52e84de74dfb0c6e23aaed4d1 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -22,6 +22,7 @@ xla_proto_library( deps = [ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo_proto", ], ) @@ -32,20 +33,25 @@ cc_library( "xrt_compilation_cache.cc", "xrt_device.cc", "xrt_state.cc", + "xrt_util.cc", ], hdrs = [ "xrt_compilation_cache.h", "xrt_device.h", "xrt_state.h", + "xrt_util.h", ], deps = [ "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:device_memory_allocator", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index dc62cf7a6b24e373374b458d2e4722e79500fb93..2ccdf0f02d840600d5e0649c4805e3672d4a1286 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ 
b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" +#include "tensorflow/compiler/xrt/xrt_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -108,19 +109,26 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, client->LoadSnapshot(computation_proto.hlo_snapshot())); - std::vector argument_layouts( + std::vector argument_layouts( + config.program_shape().parameters_size()); + std::vector argument_layout_ptrs( config.program_shape().parameters_size()); for (int i = 0; i < config.program_shape().parameters_size(); ++i) { - argument_layouts[i] = &config.program_shape().parameters(i); + argument_layouts[i] = xla::Shape(config.program_shape().parameters(i)); + argument_layout_ptrs[i] = &argument_layouts[i]; } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client->default_device_ordinal()); - build_options.set_result_layout(config.program_shape().result()); + build_options.set_result_layout(xla::Shape(config.program_shape().result())); build_options.set_device_allocator(device_ref.backend()->memory_allocator()); + if (config.has_debug_options()) { + *build_options.mutable_debug_options() = + BuildXlaDebugOptions(config.debug_options()); + } VLOG(1) << "Building executable"; auto compile_result = - client->Compile(computation, argument_layouts, build_options); + client->Compile(computation, argument_layout_ptrs, build_options); if (!compile_result.ok()) { return compile_result.status(); } @@ -174,11 +182,12 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, handle_output); xla::LocalExecutable* executable = entry->get().get_executable(); - xla::ProgramShape 
program_shape = executable->executable() - ->module() - .config() - .entry_computation_layout() - .ComputeProgramShape(); + xla::ProgramShapeProto program_shape = executable->executable() + ->module() + .config() + .entry_computation_layout() + .ComputeProgramShape() + .ToProto(); Tensor program_shape_output(DT_STRING, TensorShape({1})); program_shape_output.vec()(0) = program_shape.SerializeAsString(); ctx->set_output(1, program_shape_output); diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 8c6191ddc06ea7d85f5fd21a7d4058c669ffdeb2..751329eefc33f3372335c805233dafabbf42bf36 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -228,14 +228,35 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); - - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; - + if (config_proto.return_exploded_tuple() && + xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); + + for (int64 i = 0; i < tuple_element_count; ++i) { + xla::ShapeIndex shape_index; + shape_index.push_back(i); + + XRTTupleAllocation* suballocation; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple, shape_index, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + output_tuple->Unref(); + } else { + Tensor* 
output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tensor->scalar()() = key; + } return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index ffea592491d43788b876a51866dc8a6611e8c734..3258286c10665225aab917107ffa614459c53f3d 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_GPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_CPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); + REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .Device(DEVICE_XLA_GPU) .HostMemory("handle") diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 54b06558adcd8ef1f8f1bee52d210d558801afea..26a58fa42d8b730b365b11d2e5608e9945497763 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that writes a new literal value into device-resident memory. 
+template +class XRTWriteLiteralOp : public OpKernel { + public: + explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTWriteLiteralOp() override = default; + XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; + XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTWriteLiteralOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + const Tensor& literal_info = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), + errors::Internal("literal input should be a string scalar")); + xla::LiteralProto literal_proto; + OP_REQUIRES(ctx, + literal_proto.ParseFromString(literal_info.scalar()()), + errors::InvalidArgument( + "Unable to parse allocation input to LiteralProto")); + xla::Literal literal; + OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + typename DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, + allocation->WriteLiteral(device_ref.backend(), literal)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = allocation_handle; + ctx->set_output(0, output); + } +}; + // Op that discards a handle to device memory. 
template class XRTReleaseAllocationOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 07d025ce343f229097b557d33ad41bf9612b0696..a3d63106fa14674a9f5887ccfd908ce17dbc6384 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTWriteLiteral") + .Input("handle: int64") + .Input("literal: string") + .Output("output_handle: int64") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Copies the input literal into the device memory pointed to by handle. +Returns the handle itself. + +'handle' is the id returned from the Op that produced the on-device allocation. +'literal' is a serialized xla::LiteralProto proto to be written to device memory. +)"); + REGISTER_OP("XRTReadLiteralAndRelease") .Input("handle: int64") .Output("literal: string") diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 25464b5554d21f4b936f3f4a442fd174a8b56a8b..abaa17e50e3f5e47a45f5a8a45fa2090d3efee39 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -102,7 +102,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match " << a.DebugString() + LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() << " != " << b.DebugString(); } return equal; @@ -175,6 +175,18 @@ xla::XlaComputation AddAndTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddAndSubTuple() { + xla::XlaBuilder builder("AddAndSubTuple"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P0"); + auto p1 = 
xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P1"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {sum, sub}); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); @@ -203,6 +215,56 @@ xla::ProgramShape XlaCompiledProgramShape( ->ComputeProgramShape(); } +TEST(RawApiTest, AllocAndRewrite) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + auto read_back = ops::XRTReadLiteral(root, handle); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle = outputs[1].scalar()(); + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); + outputs.clear(); + + xla::LiteralProto new_literal = + xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); + auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), + new_literal.SerializeAsString()); + auto write_op = + ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); + TF_ASSERT_OK(root.status()); + TF_EXPECT_OK(session.Run({write_op}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(allocation_handle, outputs[0].scalar()()); + outputs.clear(); + + auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto new_response; + 
EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); + + auto release = + ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -375,9 +437,12 @@ TEST(RawApiTest, CompileAndExecute) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -411,7 +476,7 @@ TEST(RawApiTest, CompileAndExecute) { auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -427,9 +492,12 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + 
*shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -465,7 +533,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -494,8 +562,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = param_shape; - *shapes->mutable_result() = result_shape; + *shapes->add_parameters() = param_shape.ToProto(); + *shapes->mutable_result() = result_shape.ToProto(); StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot()); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -510,8 +578,9 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {c_handle.program_shape}, {release}, &outputs)); - xla::ProgramShape program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec()(0))); + xla::ProgramShapeProto program_shape_proto; + EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec()(0))); + xla::ProgramShape program_shape(program_shape_proto); EXPECT_EQ(program_shape.parameters_size(), 1); VLOG(2) << "Param: " @@ -520,7 +589,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { << xla::ShapeUtil::HumanStringWithLayout(program_shape.result()); xla::ProgramShape xla_program_shape = - XlaCompiledProgramShape(xla_computation, *shapes); + XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); 
EXPECT_TRUE(xla::LayoutUtil::Equal( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) @@ -547,11 +616,11 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); *shapes->add_parameters() = - xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}); + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto(); *shapes->add_parameters() = - xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}); + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto(); *shapes->mutable_result() = - xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}); + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto(); StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -592,7 +661,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); xrt::XRTExecutionConfig e; e.set_release_input_handles(true); @@ -632,10 +701,13 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::F32, {2})}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + 
.ToProto(); StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -671,14 +743,81 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); + + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}) + .ToProto(); + StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + 
TF_EXPECT_OK(session.Run({result}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs.front().vec(); + EXPECT_EQ(handles_vec.size(), 2); + + const float kResults[2] = {15.0f, 9.0f}; + for (int64 i = 0; i < handles_vec.size(); ++i) { + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); + std::vector voutputs; + TF_EXPECT_OK(session.Run({read_back}, &voutputs)); + EXPECT_EQ(voutputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(kResults[i]); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + } +} + TEST(RawApiTest, LeakCompilationReference) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::F32, {2})}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); @@ -703,9 +842,9 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + *shapes->add_parameters() 
= xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -739,11 +878,11 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { auto expected = xla::LiteralUtil::CreateR0(15123899); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); - EXPECT_TRUE( - xla::ShapeUtil::HasPrimitiveType(program_shape.result(), xla::S64)); + EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType( + xla::Shape(program_shape.result()), xla::S64)); } } // namespace diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 6ab77fbaaf0cbe23503ebc71775f52af01e41a74..378bb9246f27b8106310d565435404d7ac260a87 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package xrt; import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "tensorflow/compiler/xla/xla.proto"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; @@ -36,16 +37,18 @@ message XLAComputationConfig { tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3; // The arg/result shapes for the whole computation. - xla.ProgramShape program_shape = 4; + xla.ProgramShapeProto program_shape = 4; // The arg/result shapes for each core of a model-parallel // computation. per_core_args_and_result_shapes is optional for a // single-core computation. - repeated xla.ProgramShape per_core_program_shape = 5; + repeated xla.ProgramShapeProto per_core_program_shape = 5; // Describes how replicated computation instances should be assigned to // devices. 
There are num_cores_per_replica computations, and each one will be // sent and executed to the set of replica device numbers described in the // DeviceAssignment proto. DeviceAssignment device_assignment = 6; + // The debugging options to be passed to the XLA compilation process. + xla.DebugOptions debug_options = 7; } // Options and XLA computation for a compilation. @@ -98,4 +101,8 @@ message XRTExecutionConfig { bool release_input_handles = 5; // If true, release the handle to the computation after running. bool release_compilation_handle = 6; + // If set to true, and the result shape is a tuple, then instead of returning + // a single tuple allocation the execution will return a vector of + // allocations, one for each of the first-level elements of the result tuple. + bool return_exploded_tuple = 7; } diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 3a99820d7aa9e9546cc95385fd98c05f28988e9e..31603e044d17baa3ae0ae583f61837811bb12495 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include +#include #include #include #include @@ -34,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -41,6 +43,34 @@ namespace tensorflow { namespace { +class BufferAllocStats { + public: + struct Stats { + int64 count = 0; + int64 size = 0; + }; + + Stats ReportAlloc(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count += 1; + device_stats->size += msize; + return *device_stats; + } + + Stats ReportFree(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count -= 1; + device_stats->size -= msize; + return *device_stats; + } + + private: + mutable mutex lock_; + std::map stats_; +}; + const char* kTupleContainer = "tuples"; int64 get_uid() { @@ -48,6 +78,11 @@ int64 get_uid() { return static_cast(unsigned_rand); } +BufferAllocStats* GetAllocStats() { + static BufferAllocStats* stats = new BufferAllocStats(); + return stats; +} + Status AllocateScopedShapedBuffer( xla::Backend* backend, int device_ordinal, const xla::Shape& shape, std::unique_ptr* buffer) { @@ -100,9 +135,19 @@ XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, xla::DeviceMemoryAllocator* allocator) : allocation_(allocation), device_ordinal_(device_ordinal), - allocator_(allocator) {} + allocator_(allocator) { + if (VLOG_IS_ON(2)) { + auto stats = + GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); + LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ + << " count=" << stats.count << " size=" << stats.size; + } +} XRTBufferAllocation::~XRTBufferAllocation() { + if (VLOG_IS_ON(2)) { + GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); + } // Deallocate explicitly allows allocation_ to be null. 
Status s = allocator_->Deallocate(device_ordinal_, allocation_); // Nothing to do but check fail here if memory datastructures are corrupted. @@ -183,6 +228,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, return Status::OK(); } +Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, + const xla::Literal& literal) { + if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { + return errors::InvalidArgument( + "New literal shape not matching the existing one: literal=", + xla::ShapeUtil::HumanStringWithLayout(literal.shape()), + " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); + } + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + return transfer_manager->TransferLiteralToDevice(stream.get(), literal, + ToShapedBuffer()); +} + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 73b5584e38f781343fe6793af7ad28232fbfc184..3664c0cd4e6ad26945ae1012208fdb006164a066 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase { Status ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal); + // Write a new literal value to the allocation. + Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // True if none of the buffers in the allocation are aliased by any other live // handle. 
bool IsExclusiveOwner(); diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ef8bedc7324696cd255c72a851f0f2410e03848 --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xrt/xrt_util.h" + +#include +#include + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { + +bool DebugOptionsPassThroughEnabled() { + const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH"); + bool enabled = + env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); + if (enabled) { + LOG(WARNING) << "Passing through XLA debug options!"; + } else { + LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options " + "will be retained"; + } + return enabled; +} + +string SafeDebugPath(const string& path) { + if (path.empty() || path.compare(0, 5, "gs://") == 0 || + path.compare(0, 11, "bigstore://") == 0) { + return path; + } + LOG(WARNING) << "Invalid config path (will be dropped): " << path; + return string(); +} + +} // namespace + +xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) 
{ + static const bool options_passthrough = DebugOptionsPassThroughEnabled(); + if (options_passthrough) { + return ref_options; + } + xla::DebugOptions options = xla::GetDebugOptionsFromFlags(); + options.set_xla_generate_hlo_text_to( + SafeDebugPath(ref_options.xla_generate_hlo_text_to())); + options.set_xla_dump_optimized_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_optimized_hlo_proto_to())); + options.set_xla_dump_computations_to( + SafeDebugPath(ref_options.xla_dump_computations_to())); + options.set_xla_dump_executions_to( + SafeDebugPath(ref_options.xla_dump_executions_to())); + for (auto& pass : ref_options.xla_disable_hlo_passes()) { + options.add_xla_disable_hlo_passes(pass); + } + options.set_xla_dump_unoptimized_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_unoptimized_hlo_proto_to())); + options.set_xla_dump_per_pass_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_per_pass_hlo_proto_to())); + return options; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h new file mode 100644 index 0000000000000000000000000000000000000000..d9c05a7f3406313f99ae214d67b34e8e7de8be3e --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions in support of the XRT API. 
+ +#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace tensorflow { + +// Filters the debug options provided as argument according to the value of the +// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If such variable is +// set to "1" or "true", the debug options will be returned as is. Otherwise +// only a subset of them will be set in the returned ones, and all the paths +// contained in it, will be limited to gs:// and bigstore:// ones. +xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index a513aa1e7c49d64a860c740fffde156fb5bcbcf3..f6c6560c1c354ed8a36b98b1f564835eb9958e55 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -9,8 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_py_test") - py_library( name = "all_reduce_py", srcs = ["__init__.py"], @@ -29,29 +27,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nccl_ops", - ], -) - -tf_py_test( - name = "all_reduce_test", - srcs = ["python/all_reduce_test.py"], - additional_deps = [ - ":all_reduce", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:state_ops", + 
"//tensorflow/python/distribute:all_reduce", ], ) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 25f4b4b8d341331db79321338a88cabfe325eea5..238cdaf8a79812df3f043d9d070bbcfd443f6e1e 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -18,842 +18,5 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import math - -from tensorflow.python.framework import device as device_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nccl_ops - - -def _flatten_tensors(tensors): - """Check tensors for isomorphism and flatten. - - Args: - tensors: list of T `tf.Tensor` which must all have the same shape. - - Returns: - tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors - shape: the original shape of each element of input tensors - - Raises: - ValueError: tensors are empty or non-isomorphic or have unknown shape. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - for tensor in tensors: - shape = shape.merge_with(tensor.shape) - if not shape.is_fully_defined(): - raise ValueError("Tensors must have statically known shape.") - if len(shape) != 1: - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, [-1])) - tensors = reshaped - return tensors, shape - - -def _reshape_tensors(tensors, shape): - """Reshape tensors flattened by _flatten_tensors. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - shape: list of integers describing the desired shape. Product of - the elements must equal the length of each tensor. - - Returns: - list of T `tf.Tensor` which are the reshaped inputs. 
- """ - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, shape)) - return reshaped - - -def _padded_split(tensor, pieces): - """Like split for 1D tensors but pads-out case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - thin input tensor, in order. The final tensor may - be zero-padded on the end to make its size equal to those of all - of the other tensors. - - Raises: - ValueError: The input tensor is not 1D. - """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape.dims[0].value - with ops.colocate_with(tensor): - if tensor_len % pieces != 0: - # pad to an even length - chunk_size = 1 + tensor_len // pieces - if pieces > tensor_len: - # This is an edge case that should not come up in practice, - # i.e. a different reduction algorithm would be better, - # but we'll make it work just for completeness. - pad_len = pieces - tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - elif (pieces - 1) * chunk_size >= tensor_len: - # Another edge case of limited real interest. 
- pad_len = (pieces * chunk_size) % tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - else: - last_chunk_size = tensor_len - (pieces - 1) * chunk_size - pad_len = chunk_size - last_chunk_size - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - parts = array_ops.split(tensor, piece_lens) - parts[-1] = array_ops.concat( - [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - return parts, pad_len - else: - return array_ops.split(tensor, pieces), 0 - - -def _strip_padding(tensors, pad_len): - """Strip the suffix padding added by _padded_split. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - pad_len: number of elements to be stripped from the end of each tensor. - - Returns: - list of T `tf.Tensor` which are the stripped inputs. - - Raises: - ValueError: tensors must be a non-empty list of 1D tensors, and - each must be longer than pad_len. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - if len(shape) > 1: - raise ValueError("tensors must be 1D") - prefix_len = int(shape[0] - pad_len) - if prefix_len < 0: - raise ValueError("pad_len longer than tensor") - stripped = [] - for t in tensors: - with ops.colocate_with(t): - stripped.append(array_ops.slice(t, [0], [prefix_len])) - return stripped - - -def _ragged_split(tensor, pieces): - """Like split for 1D tensors but allows case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - the input tensor, in order. The final tensor may be shorter - than the others, which will all be of equal length. - - Raises: - ValueError: input tensor must be 1D. 
- """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape.dims[0].value - chunk_size = tensor_len // pieces - with ops.colocate_with(tensor): - if tensor_len != (pieces * chunk_size): - # last piece will be short - assert pieces > 1 - last_chunk_size = tensor_len - ((pieces - 1) * chunk_size) - assert last_chunk_size > 0 - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - return array_ops.split(tensor, piece_lens) - else: - return array_ops.split(tensor, pieces) - - -def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for each subchunk. - - In the basic ring reduction algorithm there are size(T)/num_devices - data chunks and each device process one chunk per tick, i.e. sending - one chunk and receiving one chunk. The idea of subchunking is that - each device processes num_subchunks smaller data regions per tick, - and the ring rank permutation is different for each subchunk index - so that a device is potentially sending to and receiving from - num_subchunks different other devices at each tick. Where multiple - independent data channels exist between devices, this strategy - supplies a method of using them in parallel. - - Args: - num_workers: number of worker tasks - num_subchunks: number of subchunks into which to divide each per-GPU chunk. - gpu_perm: an array of integers in [0, num_gpus-1] giving the default - ring order of GPUs at each worker. Other permutations will be generated - by rotating this array and splicing together per-worker instances. - - Raises: - ValueError: the number of subchunks may not exceed the number of GPUs. - - Returns: - pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - preceding device in the permutation for that subchunk. The - device index of GPU i at worker j is i + (j * num_gpus). 
- rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - local rank of device d in the permutation for that subchunk. - """ - num_gpus = len(gpu_perm) - devices = num_workers * num_gpus - if devices == 0: - return [], [] - if num_subchunks > num_gpus: - raise ValueError( - "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus)) - rotation_interval = max(1, int(num_gpus / num_subchunks)) - perms_by_s = [] - for s in range(0, num_subchunks): - full_order = [] - offset = s * rotation_interval - for w in range(0, num_workers): - default_order = [(w * num_gpus) + i for i in gpu_perm] - dev_order = default_order[offset:] + default_order[:offset] - full_order += dev_order - perms_by_s.append(full_order) - pred_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - rank_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - for s in range(0, num_subchunks): - for d in range(0, devices): - for t in range(0, devices): - if d == perms_by_s[s][t]: - rank_by_s_d[s][d] = t - pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices] - break - return (pred_by_s_d, rank_by_s_d) - - -def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, - gpu_perm, red_op, un_op=None): - """Construct a subgraph performing a ring-style all-reduce of input_tensors. - - Args: - input_tensors: a list of T `tf.Tensor` objects, which must all - have the same shape and type. - num_workers: number of worker tasks spanned by input_tensors. - num_subchunks: number of subchunks each device should process in one tick. - gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at - each worker. All workers must have the same number of - GPUs with the same rank ordering. If NVLINK is available, this should - be a ring order supported by NVLINK edges. - red_op: a binary operator for elementwise reduction. - un_op: an optional unary operator to apply to fully reduced values. 
- - Raises: - ValueError: empty input_tensors or they don't all have same - size. - - Returns: - a list of T `tf.Tensor` identical sum-reductions of input_tensors. - """ - if len(input_tensors) < 2: - raise ValueError("input_tensors must be length 2 or longer") - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - (pred_by_s_d, rank_by_s_d) = _ring_permutations( - num_workers, num_subchunks, gpu_perm) - chunks_by_dev, pad_len = _build_ring_gather( - input_tensors, devices, - num_subchunks, pred_by_s_d, rank_by_s_d, red_op) - if un_op: - chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev) - output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev) - if pad_len > 0: - output_tensors = _strip_padding(output_tensors, pad_len) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_ring_gather(input_tensors, devices, num_subchunks, - pred_by_s_d, rank_by_s_d, red_op): - """Construct a subgraph for the first (reduction) pass of ring all-reduce. - - Args: - input_tensors: a list of T `tf.Tensor` 1D input tensors of same - shape and type. - devices: array of device name strings - num_subchunks: number of subchunks each device should process in one tick. - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - red_op: a binary operator for elementwise reduction - - Raises: - ValueError: tensors must all be one dimensional. - - Returns: - list of list of T `tf.Tensor` of (partially) reduced values where - exactly num_subchunks chunks at each device are fully reduced. 
- """ - num_devices = len(input_tensors) - if num_devices == 0: - return [] - if num_devices == 1: - return input_tensors - shape = input_tensors[0].shape - if 1 != len(shape): - raise ValueError("input tensors must be 1D") - num_chunks = num_devices * num_subchunks - num_ticks = num_devices - 1 - # Initialize chunks_by_dev with splits of the input tensors. - chunks_by_dev = [] - split_pad_len = 0 - for d in range(0, num_devices): - with ops.device(devices[d]): - splits, split_pad_len = _padded_split(input_tensors[d], num_chunks) - chunks_by_dev.append(splits) - # Reduction phase - for tick in range(0, num_ticks): - # One new partial reduction for every chunk - new_partial_reductions = [None for _ in range(0, num_chunks)] - # Compute reductions with respect to last tick's values - for d in range(0, num_devices): - with ops.device(devices[d]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - new_partial_reductions[chunk_index] = red_op( - chunks_by_dev[pred_dev][chunk_index], - chunks_by_dev[d][chunk_index]) - # Update chunks_by_dev with the new values at the end of the tick. - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index] - return chunks_by_dev, split_pad_len - - -def _apply_unary_to_chunks(f, chunks_by_dev): - """Apply a unary op to each tensor in chunks_by_dev, on same device. - - Args: - f: a unary function over T `tf.Tensor`. - chunks_by_dev: list of lists of T `tf.Tensor`. - - Returns: - new list of lists of T `tf.Tensor` with the same structure as - chunks_by_dev containing the derived tensors. 
- """ - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append([f(t) for t in x]) - return output - - -def _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev): - """Construct subgraph for second (scatter) pass of ring all-reduce. - - Args: - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - chunks_by_dev: list of list of T `tf.Tensor` indexed by ints - (device, chunk) - - Raises: - ValueError: chunks_by_dev is not well-formed - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device corresponding to the outer dimension of chunks_by_dev. - """ - num_devices = len(chunks_by_dev) - num_chunks = len(chunks_by_dev[0]) - if 0 != num_chunks % num_devices: - raise ValueError( - "Expect number of chunks per device to be divisible by num_devices") - num_subchunks = int(num_chunks / num_devices) - num_ticks = num_devices - 1 - for tick in range(0, num_ticks): - passed_values = [None for _ in range(0, num_chunks)] - for d in range(0, num_devices): - with ops.colocate_with(chunks_by_dev[d][0]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - passed_values[chunk_index] = array_ops.identity( - chunks_by_dev[pred_dev][chunk_index]) - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = passed_values[chunk_index] - # Join chunks at each device. - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append(array_ops.concat(x, 0)) - return output - - -def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): - """Construct a subgraph for recursive halving-doubling all-reduce. 
- - The recursive halving-doubling algorithm is described in - http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf - - The concept is to arrange the participating n devices in - a linear sequence where devices exchange data pairwise - with one other device in each round. During the gather - phase there are lg(n) rounds where devices exchange - increasingly smaller sub-tensors with another device - at increasingly greater distances, until at the top - each device has 1/n of the fully reduced values. During the - scatter phase each device exchanges its fully reduced - sub-tensor (which doubles in length at each round) - with one other device at increasingly smaller distances - until each device has all of the fully reduced values. - - Note: this preliminary version requires that len(input_tensors) be a - power of 2. TODO(tucker): relax this restriction. Also, the - number of elements in each tensor must be divisible by 2^h where h - is the number of hops in each phase. This will also be relaxed in - the future with edge-case specific logic. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - red_op: a binary elementwise reduction Op. - un_op: an optional unary elementwise Op to apply to reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device of input_tensors. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. 
- """ - devices = [t.device for t in input_tensors] - input_tensors, shape = _flatten_tensors(input_tensors) - reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op) - if un_op: - reduced_shards = [un_op(t) for t in reduced_shards] - output_tensors = _build_recursive_hd_scatter(reduced_shards, devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_recursive_hd_gather(input_tensors, devices, red_op): - """Construct the gather phase of recursive halving-doubling all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - devices: a list of strings naming the devices hosting input_tensors, - which will also be used to host the (partial) reduction values. - red_op: a binary elementwise reduction Op. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensor shards. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - if num_devices != (2 ** num_hops): - raise ValueError("num_devices must be a power of 2") - chunks = input_tensors - for h in range(0, num_hops): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_dev = devices[d] - right_dev = devices[d + span] - left_split = array_ops.split(chunks[d], 2) - right_split = array_ops.split(chunks[d+span], 2) - with ops.device(left_dev): - new_chunks[d] = red_op(left_split[0], right_split[0]) - with ops.device(right_dev): - new_chunks[d + span] = red_op(left_split[1], right_split[1]) - chunks = new_chunks - return chunks - - -def _build_recursive_hd_scatter(input_tensors, devices): - """Construct the scatter phase of recursive halving-doublng all-reduce. 
- - Args: - input_tensors: list of T `tf.Tensor` that are fully-reduced shards. - devices: a list of strings naming the devices on which the reconstituted - full tensors should be placed. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. - """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - assert num_devices == (2 ** num_hops), "num_devices must be a power of 2" - chunks = input_tensors - for h in reversed(range(0, num_hops)): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_idx = d - right_idx = d + span - left_dev = devices[left_idx] - right_dev = devices[right_idx] - with ops.device(left_dev): - new_chunks[left_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - with ops.device(right_dev): - new_chunks[right_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - chunks = new_chunks - return chunks - - -def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): - """Construct a subgraph for shuffle all-reduce. - - Shuffle reduce is essentially the algorithm implemented when using - parameter servers. Suppose tensor length is n, there are d devices - and g gather shards. Each device sends a n/g length sub-tensor to - each gather shard. The gather shards perform a reduction across d - fragments, then broadcast the result back to each device. The - devices then join the g fully reduced fragments they receive from - the shards. The gather shards could perform d-1 pairwise - reductions, or one d-way reduction. The first is better where - reduction Op time is low compared to transmission time, the second - better in the other case. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. 
- red_op: an n-array elementwise reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - dst_devices = [t.device for t in input_tensors] - reduced_shards = _build_shuffle_gather(input_tensors, gather_devices, - red_op, un_op) - output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None): - """Construct the gather (concentrate and reduce) phase of shuffle all-reduce. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: the binary reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced shards. - - Raises: - ValueError: inputs not well-formed. - """ - num_source_devices = len(input_tensors) - num_gather_devices = len(gather_devices) - shape = input_tensors[0].shape - if len(shape) != 1: - raise ValueError("input_tensors must be 1D") - shards_by_source = [] - for d in range(0, num_source_devices): - with ops.colocate_with(input_tensors[d]): - shards_by_source.append( - _ragged_split(input_tensors[d], num_gather_devices)) - reduced_shards = [] - for d in range(0, num_gather_devices): - with ops.device(gather_devices[d]): - values = [s[d] for s in shards_by_source] - red_shard = red_op(values) - if un_op: - red_shard = un_op(red_shard) - reduced_shards.append(red_shard) - return reduced_shards - - -def _build_shuffle_scatter(reduced_shards, dst_devices): - """Build the scatter phase of shuffle all-reduce. 
- - Args: - reduced_shards: list of T @(tf.Tensor} fully reduced shards - dst_devices: list of names of devices at which the fully-reduced value - should be reconstituted. - - Returns: - list of T `tf.Tensor` scattered tensors. - """ - num_devices = len(dst_devices) - out_tensors = [] - for d in range(0, num_devices): - with ops.device(dst_devices[d]): - out_tensors.append(array_ops.concat(reduced_shards, 0)) - return out_tensors - - -def _split_by_task(devices, values): - """Partition devices and values by common task. - - Args: - devices: list of device name strings - values: list of T `tf.tensor` of same length as devices. - - Returns: - (per_task_devices, per_task_values) where both values are - lists of lists with isomorphic structure: the outer list is - indexed by task, and the inner list has length of the number - of values belonging to that task. per_task_devices contains - the specific devices to which the values are local, and - per_task_values contains the corresponding values. - - Raises: - ValueError: devices must be same length as values. - """ - num_devices = len(devices) - if num_devices != len(values): - raise ValueError("len(devices) must equal len(values)") - per_task_devices = collections.OrderedDict() - per_task_values = collections.OrderedDict() - for d in range(num_devices): - d_spec = device_lib.DeviceSpec.from_string(devices[d]) - if not hasattr(d_spec, "task") or d_spec.task is None: - assert False, "failed to parse device %s" % devices[d] - index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task) - if index not in per_task_devices: - per_task_devices[index] = [] - per_task_values[index] = [] - per_task_devices[index].append(devices[d]) - per_task_values[index].append(values[d]) - - return (list(per_task_devices.values()), list(per_task_values.values())) - - -def build_nccl_all_reduce(input_tensors, red_op, un_op=None): - """Build a subgraph that does one full all-reduce, using NCCL. 
- - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. Must be one of - {tf.add} - un_op: optional unary elementwise Op to apply to fully-reduce values. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: red_op not supported. - """ - if red_op == math_ops.add: - output_tensors = nccl_ops.all_sum(input_tensors) - else: - raise ValueError("red_op not supported by NCCL all-reduce: ", red_op) - if un_op: - un_op_wrapped = [] - for t in output_tensors: - with ops.colocate_with(t): - un_op_wrapped.append(un_op(t)) - output_tensors = un_op_wrapped - return output_tensors - - -def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): - """Construct a subgraph for NCCL hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [None for w in range(0, num_workers)] - up_devices = up_values[:] - down_values = up_values[:] - # First stage: reduce within each worker using NCCL - for w in range(0, num_workers): - worker_values = build_nccl_all_reduce(per_worker_values[w], red_op) - # NOTE: these reductions will not run to completion unless - # every output value is used. Since we only need one, we - # need to put control dependencies on the rest. 
- with ops.control_dependencies(worker_values): - with ops.device(worker_values[0].device): - up_values[w] = array_ops.identity(worker_values[0]) - up_devices[w] = per_worker_devices[w][0] - # Second stage: Apply upper_level_f to reduce across first device at - # each worker - level_2_output = upper_level_f(up_values) - # Third stage: propagate within each worker using NCCL Broadcast - for w in range(0, num_workers): - dst_tensors = [] - with ops.device(per_worker_devices[w][0]): - broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w])) - for d in per_worker_devices[w]: - with ops.device(d): - dst_tensors.append(array_ops.identity(broadcast_src)) - down_values[w] = dst_tensors - output_tensors = [v for sublist in down_values for v in sublist] - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _reduce_non_singleton(input_tensors, red_f, un_op): - """If input_tensors has more than one element apply red_f, else apply un_op.""" - if len(input_tensors) > 1: - return red_f(input_tensors) - else: - if not un_op: - return input_tensors - output_tensors = [] - for t in input_tensors: - with ops.colocate_with(t): - output_tensors.append(un_op(t)) - return output_tensors - - -def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Ring across workers.""" - def upper_builder(y): - return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op) - def upper_level_f(x): - return _reduce_non_singleton(x, upper_builder, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Recursive-HD across workers.""" - upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_shuffle(input_tensors, 
gather_devices, nccl_red_op, - shuffle_red_op, un_op=None): - """Construct hybrid of NCCL within workers, Shuffle across workers.""" - upper_level_f = lambda x: build_shuffle_all_reduce(x, gather_devices, - shuffle_red_op, un_op) - return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f) - - -def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): - """Construct a subgraph for Shuffle hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - gather_devices: list of device names on which to host gather shards. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - # First stage, reduce across each worker using gather_devices. - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [] - if len(gather_devices) != num_workers: - raise ValueError("For shuffle hybrid, gather_devices must contain one " - "device per worker. ") - for w in range(0, num_workers): - reduced_shards = _build_shuffle_gather( - per_worker_values[w], [gather_devices[w]], red_op) - up_values.append(reduced_shards[0]) - # Second stage, apply upper_level_f. - level_2_output = upper_level_f(up_values) - # Third stage, apply shuffle scatter at each worker. 
- output_tensors = [] - for w in range(0, num_workers): - output_tensors += _build_shuffle_scatter( - [level_2_output[w]], per_worker_devices[w]) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Ring across workers.""" - def upper_builder(tensors): - return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, gather_devices, red_n_op, upper_level_f) - - -def build_shuffle_then_shuffle(input_tensors, first_gather_devices, - second_gather_devices, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Shuffle across workers.""" - def upper_builder(tensors): - return build_shuffle_all_reduce(tensors, second_gather_devices, - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, first_gather_devices, red_op, upper_level_f) +# pylint: disable=unused-import,wildcard-import +from tensorflow.python.distribute.all_reduce import * diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle index 17a57b99fd6c9efc09bda0ce1249b1f51bd5af5c..ddec08894f34f96b080610f1d27a6a436f7ffa91 100644 --- a/tensorflow/contrib/android/cmake/build.gradle +++ b/tensorflow/contrib/android/cmake/build.gradle @@ -22,8 +22,8 @@ android { } externalNativeBuild { cmake { - arguments '-DANDROID_TOOLCHAIN=gcc', - '-DANDROID_STL=gnustl_static' + arguments '-DANDROID_TOOLCHAIN=clang', + '-DANDROID_STL=c++_static' } } } @@ -70,7 +70,7 @@ if (ndkDir == null || ndkDir == "") { ndkDir = System.getenv('ANDROID_NDK_HOME') } -if(! 
Os.isFamily(Os.FAMILY_WINDOWS)) { +if (!Os.isFamily(Os.FAMILY_WINDOWS)) { // This script is for non-Windows OS. For Windows OS, MANUALLY build // (or copy the built) libs/headers to the // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen diff --git a/tensorflow/contrib/autograph/examples/benchmarks/BUILD b/tensorflow/contrib/autograph/examples/benchmarks/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6d2d70c99b4cc804f2c8bf57afdc8c11f1f73516 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark") + +py_library( + name = "benchmark_base", + srcs = [ + "benchmark_base.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "cartpole_benchmark", + size = "enormous", + srcs = ["cartpole_benchmark.py"], + tags = [ + "local", + "manual", + "no_oss", + "notap", + "nozapfhahn", + ], + deps = [ + ":benchmark_base", + # Note: required gym dependency may need to be added here. + ], +) + +tf_py_logged_benchmark( + name = "cartpole_logged_benchmark", + target = "//tensorflow/contrib/autograph/examples/benchmarks:cartpole_benchmark", +) diff --git a/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py new file mode 100644 index 0000000000000000000000000000000000000000..93c694849c4dc3faca71e7f9d8614649a7784f99 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common benchmarking code. + +See https://www.tensorflow.org/community/benchmarks for usage. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +import tensorflow as tf + + +class ReportingBenchmark(tf.test.Benchmark): + """Base class for a benchmark that reports general performance metrics. + + Subclasses only need to call one of the _profile methods, and optionally + report_results. + """ + + def time_execution(self, name, target, iters, warm_up_iters=5): + for _ in range(warm_up_iters): + target() + + all_times = [] + for _ in range(iters): + iter_time = time.time() + target() + all_times.append(time.time() - iter_time) + + avg_time = np.average(all_times) + + extras = dict() + extras['all_times'] = all_times + + if isinstance(name, tuple): + extras['name'] = name + name = '_'.join(str(piece) for piece in name) + + self.report_benchmark( + iters=iters, wall_time=avg_time, name=name, extras=extras) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..4f553be58e94f11e45f0697558348fbbd26bfb91 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py @@ -0,0 +1,492 @@ +# Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A basic RL cartpole benchmark. + +The RL model uses the OpenAI Gym environment to train a simple network using +the policy gradients method. The training scales the gradients for each step +by the episode's cumulative discounted reward and averages these gradients over +a fixed number of games before applying the optimization step. + +For benchmarking purposes, we replace the OpenAI Gym environment to a fake +that returns random actions and rewards and never ends the episode. This way +the benchmarks compare the same amount of computation at each step. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import eager +from tensorflow.contrib.autograph.examples.benchmarks import benchmark_base +from tensorflow.python import autograph as ag +from tensorflow.python.eager import context + +# +# AutoGraph implementation +# + + +@ag.convert() +def graph_append_discounted_rewards(destination, rewards, discount_rate): + """Discounts episode rewards and appends them to destination.""" + ag.set_element_type(rewards, tf.float32) + + cdr = 0.0 + reverse_discounted = [] + ag.set_element_type(reverse_discounted, tf.float32) + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + cdr.set_shape(()) + reverse_discounted.append(cdr) + + retval = destination + # Note: AutoGraph doesn't yet support reversed() so we use a loop instead. + for i in range(len(reverse_discounted) - 1, -1, -1): + retval.append(reverse_discounted[i]) + + return retval + + +class GraphPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(GraphPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. + + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. 
+ """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + # TODO(mdan): Move this method out of the class. + @ag.convert() + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + var_list = tf.trainable_variables() + grad_list = [ + tf.TensorArray(tf.float32, 0, dynamic_size=True) for _ in var_list + ] + + step_counts = [] + discounted_rewards = [] + ag.set_element_type(discounted_rewards, tf.float32) + ag.set_element_type(step_counts, tf.int32) + + # Note: we use a shared object, cart_pole_env here. Because calls to the + # object's method are made through py_func, TensorFlow cannot detect its + # data dependencies. Hence we must manually synchronize access to it + # and ensure the control dependencies are set in such a way that + # calls to reset(), take_one_step, etc. are made in the correct order. 
+ sync_counter = tf.constant(0) + + for _ in tf.range(num_games): + with tf.control_dependencies([sync_counter]): + obs = cart_pole_env.reset() + with tf.control_dependencies([obs]): + sync_counter += 1 + + game_rewards = [] + ag.set_element_type(game_rewards, tf.float32) + + for step in tf.range(max_steps_per_game): + logits, actions = self(obs) # pylint:disable=not-callable + logits = tf.reshape(logits, ()) + actions = tf.reshape(actions, ()) + + labels = 1.0 - tf.cast(actions, tf.float32) + loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + grads = tf.gradients(loss, var_list) + + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + with tf.control_dependencies([sync_counter]): + obs, reward, done = cart_pole_env.step(actions) + with tf.control_dependencies([obs]): + sync_counter += 1 + obs = tf.reshape(obs, (1, 4)) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = graph_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = ag.stack(discounted_rewards) + discounted_rewards.set_shape((None,)) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = ag.stack(grad_list[i]) + + # This block just adjusts the shapes to match for multiplication. 
+ r = normalized_rewards + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return ag.stack(step_counts) + + +@ag.convert() +def graph_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + i = tf.constant(0) + mean_steps_per_iteration = [] + ag.set_element_type(mean_steps_per_iteration, tf.int32) + + while i < iterations: + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + i += 1 + + return ag.stack(mean_steps_per_iteration) + + +class GraphGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + obs = ag.utils.wrap_py_func(self.env.reset, tf.float64, ()) + obs = tf.reshape(obs, (1, 4)) + obs = tf.cast(obs, tf.float32) + return obs + + def step(self, actions): + + def take_one_step(actions): + obs, reward, done, _ = self.env.step(actions) + obs = obs.astype(np.float32) + reward = np.float32(reward) + return obs, reward, done + + return ag.utils.wrap_py_func(take_one_step, + (tf.float32, tf.float32, tf.bool), (actions,)) + + +class GraphRandomCartpoleEnv(object): + """An environment that returns random observations and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. 
+ """ + + def reset(self): + return tf.random.normal((1, 4)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = tf.random.normal((1, 4)) + fixed_reward = tf.constant(0.001) + done = tf.constant(False) + return random_obs, fixed_reward, done + + +# +# Eager implementation +# + + +def eager_append_discounted_rewards(discounted_rewards, rewards, discount_rate): + cdr = 0.0 + reverse_discounted = [] + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + reverse_discounted.append(cdr) + + discounted_rewards.extend(reversed(reverse_discounted)) + return discounted_rewards + + +class EagerPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(EagerPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. + + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood values of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. 
+ """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + self._grad_fn = eager.implicit_gradients( + self._get_cross_entropy_and_save_actions) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + def _get_cross_entropy_and_save_actions(self, inputs): + logits, actions = self(inputs) # pylint:disable=not-callable + self._current_actions = actions + labels = 1.0 - tf.cast(actions, tf.float32) + return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) + + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + grad_list = None + + step_counts = [] + discounted_rewards = [] + + for _ in range(num_games): + obs = cart_pole_env.reset() + + game_rewards = [] + + for step in range(max_steps_per_game): + grads_and_vars = self._grad_fn(tf.constant([obs], dtype=tf.float32)) + grads, var_list = zip(*grads_and_vars) + actions = self._current_actions.numpy()[0][0] + + if grad_list is None: + grad_list = [[g] for g in grads] + else: + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + obs, reward, done = cart_pole_env.step(actions) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = eager_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = tf.stack(discounted_rewards) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = tf.stack(grad_list[i]) + + r = normalized_rewards + while r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return tf.stack(step_counts) + + +def 
eager_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + mean_steps_per_iteration = [] + + for _ in range(iterations): + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + + return mean_steps_per_iteration + + +class EagerGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + return self.env.reset() + + def step(self, actions): + obs, reward, done, _ = self.env.step(actions) + return obs, reward, done + + +class EagerRandomCartpoleEnv(object): + """An environment that returns random observations and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. + """ + + def reset(self): + return np.random.normal(size=(4,)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = np.random.normal(size=(4,)) + fixed_reward = 0.001 + done = False + return random_obs, fixed_reward, done + + +def graph_demo_training(): + """Not used in the benchmark. 
Used to confirm a functional model.""" + with tf.Graph().as_default(): + tf.set_random_seed(0) + + network = GraphPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = GraphGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + train_ops = graph_train_model(network, env, opt, iterations=5) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + steps_per_iteration = sess.run(train_ops) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +def eager_demo_training(): + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = EagerGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + steps_per_iteration = eager_train_model(network, env, opt, iterations=5) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +class RLCartPoleBenchmark(benchmark_base.ReportingBenchmark): + """Actual benchmark. + + Trains the RL agent a fixed number of times, on random environments that + result in constant number of steps. 
+ """ + + def benchmark_cartpole(self): + + def train_session(sess, ops): + return lambda: sess.run(ops) + + def train_eager(network, env, opt): + return lambda: eager_train_model(network, env, opt, iterations=10) + + for model_size in (10, 100, 1000): + with tf.Graph().as_default(): + network = GraphPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = GraphRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + train_ops = graph_train_model(network, env, opt, iterations=10) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + + self.time_execution(('cartpole', 'autograph', model_size), + train_session(sess, train_ops), 20) + + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = EagerRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + self.time_execution(('cartpole', 'eager', model_size), + train_eager(network, env, opt), 20) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 55faad983f2bcf2f3fa633669bd371608e2e925b..3e4d0dc1cec76b068c1c846eb476eec615e4f613 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,8 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import function +from tensorflow.python.eager import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -101,12 +102,15 @@ def batch_function(num_batch_threads, def decorator(fn): # pylint: disable=missing-docstring def decorated(*args): # pylint: disable=missing-docstring - types = [arg.dtype 
for arg in args] - @function.Defun(*types) + @function.defun() def computation(*computation_args): return fn(*computation_args) + computation = computation.get_concrete_function( + *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i)) + for i, x in enumerate(args)]) + with ops.name_scope("batch") as name: for a in args: if not isinstance(a, ops.Tensor): @@ -123,7 +127,7 @@ def batch_function(num_batch_threads, f=computation, in_tensors=list(args), captured_tensors=computation.captured_inputs, - Tout=[o.type for o in computation.definition.signature.output_arg]) + Tout=[o.dtype for o in computation.outputs]) return decorated diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 01ee8703a93836d607ee9b765c51c79fe3bb974f..9109b9c1c91cefa4c52bad49de23336a6e05e1ef 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -219,6 +219,7 @@ class BatchOpsTest(test.TestCase): @batch_ops.batch_function(1, 10, 100000) def computation(in_t): + self.assertTrue(in_t.shape is not None) return in_t + 1 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index 13215ffabf3a956d3f83697f867457b2fa72e7c9..8b6ed9f041b89a0da02a505ec261bca82b094f74 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -81,7 +81,7 @@ class ExpectationImportanceSampleTest(test.TestCase): # Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x). # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0). 
def indicator(x): - x1_times_x2 = math_ops.reduce_prod(x, reduction_indices=[-1]) + x1_times_x2 = math_ops.reduce_prod(x, axis=[-1]) return 0.5 * (math_ops.sign(x1_times_x2) + 1.0) prob = mc.expectation_importance_sampler( diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 18d40fc1dff8e7c9aefffbe3ceba770598a42096..e83a54851195708eb7e6412b7400236f4bc06e6b 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -353,12 +353,12 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, def _sample_mean(values): """Mean over sample indices. In this module this is always [0].""" - return math_ops.reduce_mean(values, reduction_indices=[0]) + return math_ops.reduce_mean(values, axis=[0]) def _sample_max(values): """Max over sample indices. In this module this is always [0].""" - return math_ops.reduce_max(values, reduction_indices=[0]) + return math_ops.reduce_max(values, axis=[0]) def _get_samples(dist, z, n, seed): diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index 2c44abed5e1955cc666273e97e6b2378766f13d2..79052bee35c7895cb4048b10c1f73acb036d1587 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -51,25 +51,18 @@ BIGTABLE_TABLE_NAME = '' PREFIX = 'train-' def main(): + tf.enable_eager_execution() + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) table = client.table(BIGTABLE_TABLE_NAME) dataset = table.keys_by_prefix_dataset(PREFIX) - iterator = dataset.make_initializable_iterator() - get_next_op = iterator.get_next() - with tf.Session() as sess: - print('Initializing the iterator.') - sess.run(iterator.initializer) - print('Retrieving rows:') - row_index = 0 - while True: - try: - row_key = sess.run(get_next_op) - print('Row key %d: %s' % (row_index, row_key)) 
- row_index += 1 - except tf.errors.OutOfRangeError: - print('Finished reading data!') - break + print('Retrieving rows:') + row_index = 0 + for row_key in dataset: + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + print('Finished reading data!') if __name__ == '__main__': main() diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index f083ce6f44b3c2a83d9b5d3235056eb94c4be4a8..e95dc577184f7e81d942755b41065f52131ce9f6 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -366,6 +366,39 @@ BigtableTestClient::MutateRows( return MakeUnique(request.entries_size()); } +std::unique_ptr> +BigtableTestClient::AsyncMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::SampleRowKeysResponse>> +BigtableTestClient::AsyncSampleRowKeys( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::SampleRowKeysRequest& request, + ::grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::MutateRowsResponse>> +BigtableTestClient::AsyncMutateRows( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::MutateRowsRequest& request, + ::grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << 
"Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index dac2b16a216d26f02684c7401ed2ddaa4b7baddb..c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -61,6 +61,25 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { MutateRows(grpc::ClientContext* context, google::bigtable::v2::MutateRowsRequest const& request) override; + std::unique_ptr> + AsyncMutateRow(grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + grpc::CompletionQueue* cq) override; + + std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::SampleRowKeysResponse>> + AsyncSampleRowKeys( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::SampleRowKeysRequest& request, + ::grpc::CompletionQueue* cq, void* tag) override; + + std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::MutateRowsResponse>> + AsyncMutateRows(::grpc::ClientContext* context, + const ::google::bigtable::v2::MutateRowsRequest& request, + ::grpc::CompletionQueue* cq, void* tag) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index 316da9ebe152ef52c7e7f846cf8c3eb1555ee8a6..197f5578eb010bee5a3aad7c05446393193f99e2 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -57,7 +57,7 @@ class BigtableOpsTest(test.TestCase): sess.run(write_op) def runReadKeyTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr 
= dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() @@ -78,7 +78,7 @@ class BigtableOpsTest(test.TestCase): self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) def runScanTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_keys.reverse() @@ -120,7 +120,7 @@ class BigtableOpsTest(test.TestCase): def testLookup(self): ds = self._table.keys_by_prefix_dataset("r") ds = ds.apply(self._table.lookup_columns(cf1="c1")) - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) @@ -141,7 +141,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeys(self): ds = self._table.sample_keys() - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] with self.cached_session() as sess: @@ -161,7 +161,7 @@ class BigtableOpsTest(test.TestCase): sess.run(n) def runSampleKeyPairsTest(self, ds, expected_key_pairs): - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -218,7 +218,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndStartKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -226,14 +226,14 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndEndKey(self): ds = 
bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) def testParallelScanPrefix(self): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -251,7 +251,7 @@ class BigtableOpsTest(test.TestCase): def testParallelScanRange(self): ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 7c87b0daeb09950cc44c51f49c16534d413f0376..b6cdc7aab0320fe5f457288ada03a46e18a694cc 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -35,8 +35,8 @@ from tensorflow.contrib.util import loader from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader @@ -111,8 +111,7 @@ class BigtableClient(object): class BigtableTable(object): - """BigtableTable is the entrypoint for reading and writing data in Cloud - Bigtable. + """Entry point for reading and writing data in Cloud Bigtable. 
This BigtableTable class is the Python representation of the Cloud Bigtable table within TensorFlow. Methods on this class allow data to be read from and @@ -222,7 +221,7 @@ class BigtableTable(object): A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ - return _BigtablePrefixKeyDataset(self, prefix) + return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix)) def sample_keys(self): """Retrieves a sampling of row keys from the Bigtable table. @@ -234,7 +233,7 @@ class BigtableTable(object): Returns: A `tf.data.Dataset` returning string row keys. """ - return _BigtableSampleKeysDataset(self) + return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self)) def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): """Retrieves row (including values) from the Bigtable service. @@ -279,7 +278,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, prefix, "", "", normalized, probability)) def scan_range(self, start, end, probability=None, columns=None, **kwargs): """Retrieves rows (including values) from the Bigtable service. 
@@ -324,7 +324,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, "", start, end, normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, "", start, end, normalized, probability)) def parallel_scan_prefix(self, prefix, @@ -380,7 +381,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, prefix, "", "")) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -442,7 +444,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, "", start, end)) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -589,16 +592,8 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): self._table = table @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.TensorShape([]) - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) class _BigtablePrefixKeyDataset(_BigtableKeyDataset): @@ -658,16 +653,9 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * 
self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): # pylint: disable=protected-access @@ -693,16 +681,9 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._num_outputs = len(normalized) + 1 # 1 for row key @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): return gen_bigtable_ops.bigtable_scan_dataset( @@ -726,16 +707,10 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._end = end @property - def output_classes(self): - return (ops.Tensor, ops.Tensor) - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return (dtypes.string, dtypes.string) + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) def _as_variant_tensor(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 14b6fc4ac26f74f54628ae37ad6437c7d3e8caba..d3b23d949ee2c7674c3918d39e8b71d76eefcfec 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -132,6 +132,7 @@ py_library( srcs = ["estimator.py"], srcs_version = "PY2AND3", deps = [ + ":custom_loss_head", ":estimator_utils", ":model", "//tensorflow/contrib/boosted_trees:losses", 
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index a3df272e6924792128fc38fd153b9527b58b486e..b314b4d74df882a421d9a2ecce2629a63d5c5248 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -41,7 +41,8 @@ def make_custom_export_strategy(name, convert_fn, feature_columns, export_input_fn, - use_core_columns=False): + use_core_columns=False, + feature_engineering_fn=None): """Makes custom exporter of GTFlow tree format. Args: @@ -52,6 +53,7 @@ def make_custom_export_strategy(name, export_input_fn: A function that takes no arguments and returns an `InputFnOps`. use_core_columns: A boolean, whether core feature columns were used. + feature_engineering_fn: Feature eng function to be called on the input. Returns: An `ExportStrategy`. @@ -59,9 +61,12 @@ def make_custom_export_strategy(name, base_strategy = saved_model_export_utils.make_export_strategy( serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() + features = input_fn.features + if feature_engineering_fn is not None: + features, _ = feature_engineering_fn(features, labels=None) (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( - input_fn.features, feature_columns, use_core_columns) + features, feature_columns, use_core_columns) def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): """A wrapper to export to SavedModel, and convert it to other formats.""" diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index ca73e4af2fbd0a383d02fa7111f59161701661df..358404cd946bbc56d2f7228be8fe4223749c850b 100644 --- 
a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -36,7 +36,7 @@ from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn -from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 38d19976ef38a295a172e935f70bdae3c67f01e2..a178820841c4c8bcb7f5742babdb6d0f4825de31 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator @@ -26,7 +28,8 @@ from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import losses as core_losses - +from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head +from tensorflow.python.ops import array_ops # ================== Old estimator interface=================================== # The estimators below were designed for old 
feature columns and old estimator @@ -414,6 +417,108 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. + """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None, + num_quantiles=100): + """Initializes a GradientBoostedDecisionTreeQuantileRegressor instance. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + feature_engineering_fn: Feature engineering function. 
Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + """ + + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _quantile_regression_head(quantile): + # Use quantile regression. 
+ head = custom_loss_head.CustomLossHead( + loss_fn=functools.partial( + losses.per_example_quantile_regression_loss, quantile=quantile), + link_fn=array_ops.identity, + logit_dimension=label_dimension) + return head + + learner_config.num_classes = max(2, label_dimension) + + super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=model.model_builder, + params={ + 'head': _quantile_regression_head(quantiles[0]), + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'logits_modifier_function': logits_modifier_function, + 'center_bias': center_bias, + 'use_core_libs': use_core_libs, + 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, + }, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. 
@@ -437,12 +542,42 @@ def core_multiclass_head( # pylint:disable=protected-access head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( - n_classes=n_classes, loss_fn=loss_fn, loss_reduction=loss_reduction) + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) # pylint:enable=protected-access return head_fn +# For quantile regression, use this head with Core..Estimator, or use +# Core..QuantileRegressor directly, +def core_quantile_regression_head( + quantiles, + label_dimension=1, + weight_column=None, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + """Core head for quantile regression problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_quantile_regression_loss( + labels=labels, + predictions=logits, + weights=weight_column, + quantile=quantiles) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._regression_head( + label_dimension=label_dimension, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) + # pylint:enable=protected-access + return head_fn + + class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): """An estimator using gradient boosted decision trees. @@ -606,3 +741,104 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): super(CoreGradientBoostedDecisionTreeRanker, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class CoreGradientBoostedDecisionTreeQuantileRegressor( + core_estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. 
+ """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False, + num_quantiles=100): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. 
For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. + """ + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': + core_quantile_regression_head( + quantiles[0], label_dimension=label_dimension), + 'feature_columns': + feature_columns, + 'learner_config': + learner_config, + 'num_trees': + num_trees, + 'weight_column_name': + weight_column_name, + 'examples_per_layer': + examples_per_layer, + 'center_bias': + center_bias, + 'logits_modifier_function': + logits_modifier_function, + 'use_core_libs': + True, + 'output_leaf_index': + output_leaf_index, + 'override_global_step_value': + None, + 'num_quantiles': + num_quantiles, + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index c155128c0e4ccf928349ee6453baff4384222096..ee052ac60387d8f993e4942dd7dff39e191dd3a4 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.python.estimator.canned import head as head_lib +from 
tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -47,8 +48,8 @@ def _multiclass_train_input_fn(): features = { "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) } - label = constant_op.constant( - [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + label = constant_op.constant([[1], [0], [0], [2], [2], [0], [1]], + dtype=dtypes.int32) return features, label @@ -77,6 +78,59 @@ def _infer_ranking_train_input_fn(): return features, None +_QUANTILE_REGRESSION_SIZE = 1000 + + +def _quantile_regression_input_fns(two_dimension=False): + # The data generation is taken from + # http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html + np.random.seed(1) + + def f(x): + """The function to predict.""" + return x * np.sin(x) + + def g(x): + """The function to predict.""" + return x * np.cos(x) + + # Training data. + x = np.atleast_2d(np.random.uniform(0, 10.0, + size=_QUANTILE_REGRESSION_SIZE)).T + x = x.astype(np.float32) + + # Labels. + if not two_dimension: + y = f(x).ravel() + else: + y = np.column_stack((f(x).ravel(), g(x).ravel())) + + # Add random noise. + dy = 1.5 + 1.0 * np.random.random(y.shape) + noise = np.random.normal(0, dy) + y += noise + y_original = y.astype(np.float32) + if not two_dimension: + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + train_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=None, + shuffle=True) + + # Test on the training data to make sure the predictions are calibrated. 
+ test_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=1, + shuffle=False) + + return train_input_fn, test_input_fn, y_original + + class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): @@ -341,6 +395,130 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): for prediction_dict in result_iter: self.assertTrue("classes" in prediction_dict) + # One dimensional quantile regression. + def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + + # 95% percentile. + model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + frac_below_upper = round(1. 
* np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + + # 95% percentile. 
+ model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -489,8 +667,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): feature_columns = [ core_feature_column.weighted_categorical_column( - categorical_column=core_feature_column. - categorical_column_with_vocabulary_list( + categorical_column=core_feature_column + .categorical_column_with_vocabulary_list( key="word", vocabulary_list=["the", "cat", "dog"]), weight_feature_key="weight") ] @@ -509,8 +687,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): # Weights for the words are 5 - cat, 6- dog and 1 -the. features_dict["word"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], - values=constant_op.constant( - ["the", "cat", "dog", "the"], dtype=dtypes.string), + values=constant_op.constant(["the", "cat", "dog", "the"], + dtype=dtypes.string), dense_shape=[4, 3]) features_dict["weight"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], @@ -534,6 +712,132 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) + # One dimensional quantile regression. 
+ def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + frac_above_lower = round(1. 
* np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. 
* count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index 54c4ff059e3408d2cb8fc689a9ae877f57485f58..09b240a7006a8ef53eb95108b3adbfae728cf8fc 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -90,13 +90,13 @@ def _make_experiment_fn(output_dir): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() - train_input_fn = tf.estimator.inputs.numpy_input_fn( + train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_train}, y=y_train, batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( + eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py index e04b56afbfd266dc13a5b0d78d171ea273415ee3..d640af354f55423b7c9706900359f5e64c459f39 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston_combined.py +++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py @@ -80,13 +80,13 @@ def _make_experiment_fn(output_dir): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() - train_input_fn = tf.estimator.inputs.numpy_input_fn( + train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_train}, y=y_train, batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = 
tf.estimator.inputs.numpy_input_fn( + eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 8edb5d6c640611bbb90d7731b2fea4354e125563..6d78e27e8f69ea289b686af8402bd91967f997f4 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -834,8 +834,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { root_gradient_stats *= normalizer_ratio; NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; + bool best_feature_updated = false; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); + CHECK(end_index - start_index >= 2) + << "Partition should have a non bias feature. Start index " + << start_index << " and end index " << end_index; + for (int64 feature_idx = start_index + 1; feature_idx < end_index; ++feature_idx) { GradientStats left_gradient_stats(*gradients_t, *hessians_t, @@ -845,11 +850,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { root_gradient_stats - left_gradient_stats; NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); - if (left_stats.gain + right_stats.gain > best_gain) { + if (!best_feature_updated || + left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; best_right_node_stats = right_stats; best_feature_idx = feature_idx; + best_feature_updated = true; } } SplitInfo split_info; @@ -864,7 +871,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { << feature_ids(best_feature_idx, 0) << ", " << feature_ids(best_feature_idx, 1) << "\nPartition IDS: " << partition_ids(start_index) << 
" " - << partition_ids(best_feature_idx); + << partition_ids(best_feature_idx) << " and best gain " << best_gain; equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 4da25298cb82093ac501997cc21c48265df06860..d26af58419752170bbc58bba757ac43349fc2cff 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -119,7 +119,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): def not_active_inputs(): return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) def active_inputs(): diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index a2f708081a4b484d649b5d09b172c2c60db69aeb..386dc19fc7b9529993a9625fb1298f6eb9a70d87 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -36,9 +36,9 @@ def get_empty_tensors(gradient_shape, hessian_shape): empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) return 
empty_gradients, empty_hessians @@ -486,8 +486,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] - indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) - values = array_ops.constant([], dtype=dtypes.int64) + indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) + values = constant_op.constant_v1([], dtype=dtypes.int64) gradient_shape = tensor_shape.scalar() hessian_shape = tensor_shape.scalar() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 1fffbb5f660c681e1dde11a2aaf1d0f1cf79d1d0..0476bed2cd3f3ea5b47b10c51a819f17d6e37c74 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -605,7 +605,7 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantile_buckets, example_partition_ids, gradients, hessians, weights, empty_gradients, empty_hessians): """Updates the state for dense split handler.""" - empty_float = constant_op.constant([], dtype=dtypes.float32) + empty_float = constant_op.constant_v1([], dtype=dtypes.float32) quantile_values, quantile_weights = control_flow_ops.cond( is_active[1], # For the next layer, this handler is inactive. 
@@ -621,8 +621,8 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): - return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), + return (constant_op.constant_v1([], dtype=dtypes.int32), + constant_op.constant_v1([[]], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) example_partition_ids, feature_ids, gradients, hessians = ( @@ -708,11 +708,11 @@ def sparse_make_stats_update( def quantiles_not_ready(): """The subgraph for when the quantiles are not ready.""" - return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + return (constant_op.constant_v1([], dtype=dtypes.int32), + constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) - empty_float = constant_op.constant([], dtype=dtypes.float32) + empty_float = constant_op.constant_v1([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant([0, 1], dtype=dtypes.int64), diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 74b0ea6989c65e83e7a466107d624712a0e72d1b..4a1b528646e7d2139d7eabb0264b8d280f8da133 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -39,9 +39,9 @@ def get_empty_tensors(gradient_shape, hessian_shape): empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = 
constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) return empty_gradients, empty_hessians @@ -1476,9 +1476,9 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testEmpty(self): with self.cached_session() as sess: - indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) # No values in this feature column in this mini-batch. - values = array_ops.constant([], dtype=dtypes.float32) + values = constant_op.constant_v1([], dtype=dtypes.float32) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) gradient_shape = tensor_shape.scalar() @@ -1549,8 +1549,9 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): sparse_column = array_ops.sparse_placeholder(dtypes.float32) # We have two batches - at first, a sparse feature is empty. - empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) - empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_indices = constant_op.constant_v1([], dtype=dtypes.int64, + shape=[0, 2]) + empty_values = constant_op.constant_v1([], dtype=dtypes.float32) empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, empty_values, [4, 2]) empty_sparse_column = empty_sparse_column.eval(session=sess) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index ab5713fbe26ab76eac923035e9feecc2ec51f492..9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -897,9 +897,9 @@ class GradientBoostedDecisionTreeModel(object): empty_hess_shape = [1] + self._hessian_shape.as_list() empty_grad_shape = [1] + self._gradient_shape.as_list() - empty_gradients = constant_op.constant( + 
empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) active_handlers = array_ops.unstack(active_handlers, axis=0) @@ -1257,13 +1257,12 @@ class GradientBoostedDecisionTreeModel(object): def _get_replica_device_setter(self, worker_device): """Creates a replica device setter.""" ps_tasks = self._num_ps_replicas - ps_ops = [ - "Variable", - "VariableV2", + ps_ops = list(device_setter.STANDARD_PS_OPS) + ps_ops.extend([ "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - ] + ]) ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( worker_device=worker_device, diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index b5ebaf1999519f65110e8164fa20bace5ecc3ef6..220e981618b7c0bfb1e4e98c087d83b451b9b3cf 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -48,6 +48,47 @@ def per_example_logistic_loss(labels, weights, predictions): labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() +# MUST USE WITH HESSIAN REGULARIZATION, +# This loss can have zero hessian, so it must be used with l2 or min_node_weight +# regularization. +# An example config is +# learner_config.constraints.min_node_weight = 1 / num_examples_per_layer +# learner_config.regularization.l2 = 1.0 / num_examples_per_layer +# TODO(nponomareva): make it multidimensional so we can estimate several +# quantiles at once. +def per_example_quantile_regression_loss(labels, weights, predictions, + quantile): + """Smoothed loss for quantile regression. 
+ + The standard quantile regression loss is quantile*(y-y') when y>y' and + (quantile-1)*(y-y') otherwise, y' is a prediction, y is a label. The impl + below is this loss but squared in the region where the loss value < 1. + + Args: + labels: Rank 2 (N, D) tensor of per-example labels. + weights: Rank 2 (N, 1) tensor of per-example weights. + predictions: Rank 2 (N, D) tensor of per-example predictions. + quantile: The quantile to use. + + Returns: + loss: A Rank 2 (N, 1) tensor of per-example quantile loss. + update_op: An update operation to update the loss's internal state. + """ + labels = math_ops.to_float(labels) + error = labels - predictions + square_loss_right = array_ops.where(error * quantile < 1.0, + math_ops.square(quantile * error), + quantile * error) + square_loss_left = array_ops.where(error * (quantile - 1) < 1, + math_ops.square((quantile - 1) * error), + (quantile - 1) * error) + + unweighted_loss = array_ops.where(error > 0, square_loss_right, + square_loss_left) + if weights is None: + return unweighted_loss, control_flow_ops.no_op() + else: + return unweighted_loss * weights, control_flow_ops.no_op() # This is classical form of Maximum entropy loss, that is twice differentiable # (sparse_softmax_cross_entropy which is what we go for is not twice @@ -78,8 +119,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): labels = array_ops.expand_dims(labels, 1) # Labels are indices of classes, convert them to one hot encodings. target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes) - labels = math_ops.reduce_sum( - input_tensor=target_one_hot, reduction_indices=[1]) + labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1]) labels = math_ops.to_float(labels) # Calculate softmax probabilities for each class. 
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 242c1e8ba45e0b2f6f9a1a51695b824546382666..5418e2605b724edb60878e250d2c50fcc6ff5633 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -46,6 +46,10 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): self._maybe_initialize_checkpointable() self._name_counts = {} + @property + def _values(self): + return [dep.ref for dep in self._checkpoint_dependencies] + def track(self, checkpointable, base_name): """Add a dependency on `checkpointable`. diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 9e1867ea9d0c72596f5cc848b25331d79fa84c24..f944b7f88438ff257a44581170ead16640540e69 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -21,173 +21,25 @@ py_library( py_library( name = "cluster_resolver_py", - srcs = [ + srcs = glob([ "__init__.py", - "python/training/__init__.py", - ], + "python/training/*.py", + ]), srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - ":base_cluster_resolver_py", - ":gce_cluster_resolver_py", - ":kubernetes_cluster_resolver_py", - ":slurm_cluster_resolver_py", - ":tfconfig_cluster_resolver_py", - ":tpu_cluster_resolver_py", - "//tensorflow/python:util", - ], -) - -py_library( - name = "base_cluster_resolver_py", - srcs = ["python/training/cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:training", - ], -) - -py_library( - name = "gce_cluster_resolver_py", - srcs = ["python/training/gce_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "tfconfig_cluster_resolver_py", - srcs = ["python/training/tfconfig_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ 
- ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "tpu_cluster_resolver_py", - srcs = ["python/training/tpu_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "slurm_cluster_resolver_py", - srcs = ["python/training/slurm_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "kubernetes_cluster_resolver_py", - srcs = ["python/training/kubernetes_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -tf_py_test( - name = "base_cluster_resolver_py_test", - srcs = ["python/training/cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/cluster_resolver_test.py", -) - -tf_py_test( - name = "gce_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/gce_cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - ":gce_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/gce_cluster_resolver_test.py", -) - -tf_py_test( - name = "tfconfig_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/tfconfig_cluster_resolver_test.py"], - additional_deps = [ - ":tfconfig_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - 
"//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - grpc_enabled = True, - main = "python/training/tfconfig_cluster_resolver_test.py", -) - -tf_py_test( - name = "tpu_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/tpu_cluster_resolver_test.py"], - additional_deps = [ - ":tpu_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - grpc_enabled = True, - main = "python/training/tpu_cluster_resolver_test.py", -) - -tf_py_test( - name = "slurm_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/slurm_cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - ":slurm_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/slurm_cluster_resolver_test.py", - tags = [], + deps = ["//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib"], ) tf_py_test( - name = "kubernetes_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/kubernetes_cluster_resolver_test.py"], + name = "cluster_resolver_initialization_test", + srcs = ["cluster_resolver_initialization_test.py"], additional_deps = [ ":cluster_resolver_py", - ":kubernetes_cluster_resolver_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], - main = "python/training/kubernetes_cluster_resolver_test.py", + main = "cluster_resolver_initialization_test.py", ) diff --git 
a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index fd1263fe81ae826d5edfa8752460fb78fe52b32a..390b3e7550b3d991269bb84707c3500f2fa33290 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -20,12 +20,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver # pylint: enable=wildcard-import,unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -35,6 
+37,8 @@ _allowed_symbols = [ 'SimpleClusterResolver', 'UnionClusterResolver', 'GceClusterResolver', + 'KubernetesClusterResolver', + 'TFConfigClusterResolver', 'TPUClusterResolver', 'SlurmClusterResolver', ] diff --git a/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py b/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..01ff1478c694cf0901aeed48b6e0f873d8abe65e --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests to ensure ClusterResolvers are usable via the old contrib path.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cluster_resolver import SimpleClusterResolver +from tensorflow.contrib.cluster_resolver.python.training import cluster_resolver +from tensorflow.contrib.cluster_resolver.python.training import UnionClusterResolver +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + + +class ClusterResolverInitializationTest(test.TestCase): + + def testCreateSimpleClusterResolverFromLib(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + cluster_resolver.SimpleClusterResolver(base_cluster_spec) + + def testCreateSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + SimpleClusterResolver(base_cluster_spec) + + def testCreateUnionClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + simple_cr = SimpleClusterResolver(base_cluster_spec) + UnionClusterResolver(simple_cr) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 6d9120a3b96e1960a438772e282ef653b364b7eb..10d93549ebbd4f7e900796d0516b0af1744224af 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -18,11 +18,36 @@ from __future__ import absolute_import from __future__ import division from __future__ import 
print_function -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
+ +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'cluster_resolver', + 'gce_cluster_resolver', + 'kubernetes_cluster_resolver', + 'slurm_cluster_resolver', + 'tfconfig_cluster_resolver', + 'tpu_cluster_resolver', + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', + 'GceClusterResolver', + 'KubernetesClusterResolver', + 'TFConfigClusterResolver', + 'TPUClusterResolver', + 'SlurmClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 40b1e667ee6039b44b1a442d41dc28dfcbad6dc6..99840fb5166dd739b3bee06a926e06b534011d1f 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,333 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution.""" +"""Stub file for ClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -import six +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +# pylint: enable=unused-import -from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', +] -def _format_master_url(master, rpc_layer=None): - if rpc_layer: - return '%s://%s' % (rpc_layer, master) - else: - return master +remove_undocumented(__name__, _allowed_symbols) - -@six.add_metaclass(abc.ABCMeta) -class ClusterResolver(object): - """Abstract class for all implementations of ClusterResolvers. - - This defines the skeleton for all implementations of ClusterResolvers. - ClusterResolvers are a way for TensorFlow to communicate with various cluster - management systems (e.g. GCE, AWS, etc...). 
- - By letting TensorFlow communicate with these systems, we will be able to - automatically discover and resolve IP addresses for various TensorFlow - workers. This will eventually allow us to automatically recover from - underlying machine failures and scale TensorFlow worker clusters up and down. - """ - - @abc.abstractmethod - def cluster_spec(self): - """Retrieve the current state of the cluster and returns a ClusterSpec. - - Returns: - A ClusterSpec representing the state of the cluster at the moment this - function is called. - - Implementors of this function must take care in ensuring that the - ClusterSpec returned is up-to-date at the time of calling this function. - This usually means retrieving the information from the underlying cluster - management system every time this function is invoked and reconstructing - a cluster_spec, rather than attempting to cache anything. - """ - raise NotImplementedError( - 'cluster_spec is not implemented for {}.'.format(self)) - - @abc.abstractmethod - def master(self, task_type=None, task_index=None, rpc_layer=None): - """Retrieves the name or URL of the session master. - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - rpc_layer: (Optional) The RPC protocol for the given cluster. - - Returns: - The name or URL of the session master. - - Implementors of this function must take care in ensuring that the master - returned is up-to-date at the time to calling this function. This usually - means retrieving the master every time this function is invoked. 
- """ - raise NotImplementedError('master is not implemented for {}.'.format(self)) - - -class SimpleClusterResolver(ClusterResolver): - """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - - def __init__(self, cluster_spec, master='', task_type=None, task_index=None, - environment='', num_accelerators_per_worker=0, - rpc_layer=None): - """Creates a SimpleClusterResolver from a ClusterSpec.""" - super(SimpleClusterResolver, self).__init__() - - self._task_type = task_type - self._task_index = task_index - self._environment = environment - self._num_accelerators_per_worker = num_accelerators_per_worker - self._rpc_layer = rpc_layer - - if not isinstance(cluster_spec, ClusterSpec): - raise TypeError('cluster_spec must be a ClusterSpec.') - self._cluster_spec = cluster_spec - - if not isinstance(master, str): - raise TypeError('master must be a string.') - self._master = master - - def cluster_spec(self): - """Returns the ClusterSpec passed into the constructor.""" - return self._cluster_spec - - def master(self, task_type=None, task_index=None, rpc_layer=None): - """Returns the master address to use when creating a session. - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - rpc_layer: (Optional) The RPC used by distributed TensorFlow. - - Returns: - The name or URL of the session master. - - If a task_type and task_index is given, this will override the `master` - string passed into the initialization function. 
- """ - if task_type is not None and task_index is not None: - master = self.cluster_spec().task_address(task_type, task_index) - else: - master = self._master - - return _format_master_url(master, rpc_layer or self._rpc_layer) - - @property - def task_type(self): - return self._task_type - - @property - def task_index(self): - return self._task_index - - @task_type.setter - def task_type(self, task_type): - self._task_type = task_type - - @task_index.setter - def task_index(self, task_index): - self._task_index = task_index - - @property - def environment(self): - return self._environment - - def num_accelerators_per_worker(self, session_config=None): - """Returns the number of accelerator cores per worker. - - Args: - session_config: Unused. The SimpleClusterResolver does not do automatic - detection of accelerators, so a TensorFlow session will never be - created, and thus a `session_config` is never necessary here, and will - be ignored. - """ - del session_config - return self._num_accelerators_per_worker - - @property - def rpc_layer(self): - return self._rpc_layer - - @rpc_layer.setter - def rpc_layer(self, rpc_layer): - self._rpc_layer = rpc_layer - - -class UnionClusterResolver(ClusterResolver): - """Performs a union on underlying ClusterResolvers. - - This class performs a union given two or more existing ClusterResolvers. It - merges the underlying ClusterResolvers, and returns one unified ClusterSpec - when cluster_spec is called. The details of the merge function is - documented in the cluster_spec function. - - For additional Cluster Resolver properties such as task type, task index, - rpc layer, environment, etc..., we will return the value from the first - ClusterResolver in the union. - """ - - def __init__(self, *args, **kwargs): - """Initializes a UnionClusterResolver with other ClusterResolvers. - - Args: - *args: `ClusterResolver` objects to be unionized. 
- **kwargs: - rpc_layer - (Optional) Override value for the RPC layer used by - TensorFlow. - task_type - (Optional) Override value for the current task type. - task_index - (Optional) Override value for the current task index. - - Raises: - TypeError: If any argument is not a subclass of `ClusterResolvers`. - ValueError: If there are no arguments passed. - """ - super(UnionClusterResolver, self).__init__() - - self._rpc_layer = kwargs.pop('rpc_layer', None) - self._task_type = kwargs.pop('task_type', None) - self._task_index = kwargs.pop('task_index', None) - - if kwargs: - raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs)) - - if not args: - raise ValueError('At least one ClusterResolver is required.') - - for cluster_resolver in args: - if not isinstance(cluster_resolver, ClusterResolver): - raise TypeError('All arguments must be a sub-class of ' - '`ClusterResolver.`') - self._cluster_resolvers = args - - def cluster_spec(self): - """Returns a union of all the ClusterSpecs from the ClusterResolvers. - - Returns: - A ClusterSpec containing host information merged from all the underlying - ClusterResolvers. - - Raises: - KeyError: If there are conflicting keys detected when merging two or - more dictionaries, this exception is raised. - - Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the - same job name, we will merge the list/dict of workers. - - If *all* underlying ClusterSpecs expose the set of workers as lists, we will - concatenate the lists of workers, starting with the list of workers from - the first ClusterResolver passed into the constructor. - - If *any* of the ClusterSpecs expose the set of workers as a dict, we will - treat all the sets of workers as dicts (even if they are returned as lists) - and will only merge them into a dict if there is no conflicting keys. If - there is a conflicting key, we will raise a `KeyError`. 
- """ - - merged_cluster = {} - - # We figure out whether it is all lists for a particular job, or whether - # there are dicts inside. - for cluster_resolver in self._cluster_resolvers: - cluster_spec = cluster_resolver.cluster_spec() - cluster_dict = cluster_spec.as_dict() - - for job_name, tasks in cluster_dict.items(): - if job_name in merged_cluster: - # If we see a dict, then we write a dict out regardless. - if isinstance(tasks, dict): - merged_cluster[job_name] = {} - else: - # We take whichever type is present. - if isinstance(tasks, list): - merged_cluster[job_name] = [] - else: - merged_cluster[job_name] = {} - - # We then do the merge as appropriate in merged_cluster[job]. - for cluster_resolver in self._cluster_resolvers: - cluster_spec = cluster_resolver.cluster_spec() - cluster_dict = cluster_spec.as_dict() - - for job_name, tasks in cluster_dict.items(): - if isinstance(merged_cluster[job_name], list): - # We all have lists, we can just concatenate and be done. - merged_cluster[job_name].extend(tasks) - else: - if isinstance(tasks, list): - # We convert to a dictionary if the type is a list. - task_dict = dict(zip(range(0, len(tasks)), tasks)) - else: - # We can simply make a copy (for update) and be done. - task_dict = tasks.copy() - - # We detect if there are duplicates, and raise an error if so. - task_keys = set(task_dict) - merged_keys = set(merged_cluster[job_name].keys()) - intersected_keys = task_keys.intersection(merged_keys) - if intersected_keys: - raise KeyError('Duplicate keys detected when merging two ' - 'ClusterSpecs: %s' % repr(intersected_keys)) - - # We do the merge after all the processing. - merged_cluster[job_name].update(task_dict) - - return ClusterSpec(merged_cluster) - - def master(self, task_type=None, task_index=None, rpc_layer=None): - """Returns the master address to use when creating a session. 
- - This usually returns the master from the first ClusterResolver passed in, - but you can override this by specifying the task_type and task_index. - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - rpc_layer: (Optional) The RPC protocol for the given cluster. - - Returns: - The name or URL of the session master. - """ - if task_type is not None and task_index is not None: - master = self.cluster_spec().task_address(task_type, task_index) - return _format_master_url(master, rpc_layer or self._rpc_layer) - - return self._cluster_resolvers[0].master(rpc_layer=rpc_layer) - - @property - def task_type(self): - return self._task_type or self._cluster_resolvers[0].task_type - - @property - def task_index(self): - return self._task_index or self._cluster_resolvers[0].task_index - - @task_type.setter - def task_type(self, task_type): - self._task_type = task_type - - @task_index.setter - def task_index(self, task_index): - self._task_index = task_index - - @property - def environment(self): - return self._cluster_resolvers[0].environment - - def num_accelerators_per_worker(self, session_config=None): - return self._cluster_resolvers[0].num_accelerators_per_worker( - session_config) - - @property - def rpc_layer(self): - return self._rpc_layer or self._cluster_resolvers[0].rpc_layer - - @rpc_layer.setter - def rpc_layer(self, rpc_layer): - self._rpc_layer = rpc_layer diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 195b68959b6d21ef674438a4a23a4dd07f45faa7..55e61155c683c928efab9bb018868faec3e3df8c 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,197 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for GCE Instance Groups.""" +"""Stub file for GceClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +# pylint: enable=unused-import -_GOOGLE_API_CLIENT_INSTALLED = True -try: - from googleapiclient import discovery # pylint: disable=g-import-not-at-top - from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top -except ImportError: - _GOOGLE_API_CLIENT_INSTALLED = False +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + 'GceClusterResolver', +] -def _format_master_url(master, rpc_layer=None): - return '%s://%s' % (rpc_layer, master) if rpc_layer else master - - -class GceClusterResolver(ClusterResolver): - """Cluster Resolver for Google Compute Engine. - - This is an implementation of cluster resolvers for the Google Compute Engine - instance group platform. 
By specifying a project, zone, and instance group, - this will retrieve the IP address of all the instances within the instance - group and return a Cluster Resolver object suitable for use for distributed - TensorFlow. - """ - - def __init__(self, - project, - zone, - instance_group, - port, - task_type='worker', - task_index=0, - rpc_layer='grpc', - num_accelerators_per_worker=0, - credentials='default', - service=None): - """Creates a new GceClusterResolver object. - - This takes in a few parameters and creates a GceClusterResolver project. It - will then use these parameters to query the GCE API for the IP addresses of - each instance in the instance group. - - Args: - project: Name of the GCE project. - zone: Zone of the GCE instance group. - instance_group: Name of the GCE instance group. - port: Port of the listening TensorFlow server (default: 8470) - task_type: Name of the TensorFlow job this GCE instance group of VM - instances belong to. - task_index: The task index for this particular VM, within the GCE - instance group. In particular, every single instance should be assigned - a unique ordinal index within an instance group manually so that they - can be distinguished from each other. - rpc_layer: The RPC layer TensorFlow should use to communicate across - instances. - num_accelerators_per_worker: Number of accelerators (GPUs) present per - instance. - credentials: GCE Credentials. If nothing is specified, this defaults to - GoogleCredentials.get_application_default(). - service: The GCE API object returned by the googleapiclient.discovery - function. (Default: discovery.build('compute', 'v1')). If you specify a - custom service object, then the credentials parameter will be ignored. - - Raises: - ImportError: If the googleapiclient is not installed. 
- """ - self._project = project - self._zone = zone - self._instance_group = instance_group - self._task_type = task_type - self._task_index = task_index - self._rpc_layer = rpc_layer - self._port = port - self._credentials = credentials - - if credentials == 'default': - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'GCE cluster resolver') - self._service = discovery.build( - 'compute', 'v1', - credentials=self._credentials) - else: - self._service = service - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest instance group info. - - This returns a ClusterSpec object for use based on information from the - specified instance group. We will retrieve the information from the GCE APIs - every time this method is called. - - Returns: - A ClusterSpec containing host information retrieved from GCE. 
- """ - request_body = {'instanceState': 'RUNNING'} - request = self._service.instanceGroups().listInstances( - project=self._project, - zone=self._zone, - instanceGroups=self._instance_group, - body=request_body, - orderBy='name') - - worker_list = [] - - while request is not None: - response = request.execute() - - items = response['items'] - for instance in items: - instance_name = instance['instance'].split('/')[-1] - - instance_request = self._service.instances().get( - project=self._project, - zone=self._zone, - instance=instance_name) - - if instance_request is not None: - instance_details = instance_request.execute() - ip_address = instance_details['networkInterfaces'][0]['networkIP'] - instance_url = '%s:%s' % (ip_address, self._port) - worker_list.append(instance_url) - - request = self._service.instanceGroups().listInstances_next( - previous_request=request, - previous_response=response) - - worker_list.sort() - return ClusterSpec({self._task_type: worker_list}) - - def master(self, task_type=None, task_index=None, rpc_layer=None): - task_type = task_type if task_type is not None else self._task_type - task_index = task_index if task_index is not None else self._task_index - - if task_type is not None and task_index is not None: - master = self.cluster_spec().task_address(task_type, task_index) - if rpc_layer or self._rpc_layer: - return '%s://%s' % (rpc_layer or self._rpc_layer, master) - else: - return master - - return '' - - @property - def task_type(self): - return self._task_type - - @property - def task_index(self): - return self._task_index - - @task_type.setter - def task_type(self, task_type): - raise RuntimeError( - 'You cannot reset the task_type of the GceClusterResolver after it has ' - 'been created.') - - @task_index.setter - def task_index(self, task_index): - self._task_index = task_index - - @property - def environment(self): - """Returns the current environment which TensorFlow is running in. 
- - For users in the GCE environment, the environment property is always an - empty string, and Google users will not use this ClusterResolver for running - on internal systems. - """ - return '' - - @property - def rpc_layer(self): - return self._rpc_layer - - @rpc_layer.setter - def rpc_layer(self, rpc_layer): - self._rpc_layer = rpc_layer - - def num_accelerators_per_worker(self, session_config=None): - del session_config # Unused, since this is set manually in __init__. - return self._num_accelerators_per_worker - +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py index ddae64839f01b4f67fe4c0c0bc00199bb2e037aa..a8eaf33629a6299d5da5f8a930e0cad7d07044e8 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py @@ -12,121 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for Kubernetes.""" +"""Stub file for KubernetesClusterResolver for backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training import server_lib +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
-_KUBERNETES_API_CLIENT_INSTALLED = True -try: - from kubernetes import client as k8sclient # pylint: disable=g-import-not-at-top - from kubernetes import config as k8sconfig # pylint: disable=g-import-not-at-top -except ImportError: - _KUBERNETES_API_CLIENT_INSTALLED = False +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented -class KubernetesClusterResolver(ClusterResolver): - """Cluster Resolver for Kubernetes. +_allowed_symbols = [ + 'KubernetesClusterResolver', +] - This is an implementation of cluster resolvers for Kubernetes. When given the - the Kubernetes namespace and label selector for pods, we will retrieve the - pod IP addresses of all running pods matching the selector, and return a - ClusterSpec based on that information. - """ +remove_undocumented(__name__, _allowed_symbols) - def __init__(self, - job_to_label_mapping=None, - tf_server_port=8470, - override_client=None): - """Initializes a new KubernetesClusterResolver. - - This initializes a new Kubernetes Cluster Resolver. The Cluster Resolver - will attempt to talk to the Kubernetes master to retrieve all the instances - of pods matching a label selector. - - Args: - job_to_label_mapping: A mapping of TensorFlow jobs to label selectors. - This allows users to specify many TensorFlow jobs in one Cluster - Resolver, and each job can have pods belong with different label - selectors. For example, a sample mapping might be - ``` - {'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'], - 'ps': ['job-name=ps-1', 'job-name=ps-2']} - ``` - tf_server_port: The port the TensorFlow server is listening on. - override_client: The Kubernetes client (usually automatically retrieved - using `from kubernetes import client as k8sclient`). 
If you pass this - in, you are responsible for setting Kubernetes credentials manually. - - Raises: - ImportError: If the Kubernetes Python client is not installed and no - `override_client` is passed in. - """ - if _KUBERNETES_API_CLIENT_INSTALLED: - k8sconfig.load_kube_config() - - if not job_to_label_mapping: - job_to_label_mapping = {'worker': ['job-name=tensorflow']} - - if not override_client and not _KUBERNETES_API_CLIENT_INSTALLED: - raise ImportError('The Kubernetes Python client must be installed before' - 'using the Kubernetes Cluster Resolver. To install the' - 'Kubernetes Python client, run `pip install ' - 'kubernetes` on your command line.') - - self._job_to_label_mapping = job_to_label_mapping - self._tf_server_port = tf_server_port - self._override_client = override_client - - def master(self): - # TODO(frankchn): Figure out a standard way to pass in the current task type - # and task id via Kubernetes. - pass - - def get_master(self): - return self.master() - - def get_job_name(self): - return self._job_name - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest info from Kubernetes. - - We retrieve the information from the Kubernetes master every time this - method is called. - - Returns: - A ClusterSpec containing host information returned from Kubernetes. - - Raises: - RuntimeError: If any of the pods returned by the master is not in the - `Running` phase. - """ - if not self._override_client: - k8sconfig.load_kube_config() - - client = self._override_client or k8sclient.CoreV1Api() - cluster_map = {} - - for tf_job in self._job_to_label_mapping: - all_pods = [] - for selector in self._job_to_label_mapping[tf_job]: - ret = client.list_pod_for_all_namespaces(label_selector=selector) - selected_pods = [] - - # Sort the list by the name to make sure it doesn't change call to call. 
- for pod in sorted(ret.items, key=lambda x: x.metadata.name): - if pod.status.phase == 'Running': - selected_pods.append( - '%s:%s' % (pod.status.host_ip, self._tf_server_port)) - else: - raise RuntimeError('Pod "%s" is not running; phase: "%s"' % - (pod.metadata.name, pod.status.phase)) - all_pods.extend(selected_pods) - cluster_map[tf_job] = all_pods - - return server_lib.ClusterSpec(cluster_map) diff --git a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py index dabe2fe1d39db14c60e5437d636144f18c384cf1..fcd2a846eeb1be7ad4b5a98b067a125afbbebc7d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py @@ -12,185 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for Slurm workload manager.""" +"""Stub file for SlurmClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import os -import subprocess +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
-from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented -class SlurmClusterResolver(ClusterResolver): - """Cluster Resolver for system with Slurm workload manager. +_allowed_symbols = [ + 'SlurmClusterResolver', +] - This is an implementation of cluster resolvers for Slurm clusters. This allows - the specification of jobs and task counts, number of tasks per node, number of - GPUs on each node and number of GPUs for each task, It retrieves system - attributes by Slurm environment variables, resolves allocated computing node - names, construct a cluster and return a Cluster Resolver object which an be - use for distributed TensorFlow. - """ - - def _resolve_hostnames(self): - """Resolve host names of nodes allocated in current jobs. - - Returns: - A list of node names as strings. - """ - hostlist = (subprocess.check_output(['scontrol', 'show', 'hostname']). - decode('utf-8').strip().split('\n')) - return hostlist - - def __init__(self, - jobs, - port_base=8888, - gpus_per_node=1, - gpus_per_task=1, - tasks_per_node=None, - auto_set_gpu=True): - """Creates a new SlurmClusterResolver object. - - This takes in parameters and creates a SlurmClusterResolver object. It uses - those parameters to check which nodes will processes reside and resolves - their hostnames. With the number of the GPUs on each node and number of GPUs - for each task it offsets the port number for each processes and allocate - GPUs to tasks by setting environment variables. The resolver currently - supports homogeneous tasks and default Slurm process allocation. 
- - Args: - jobs: Dictionary with job names as key and number of tasks in the job as - value - port_base: The first port number to start with for processes on a node. - gpus_per_node: Number of GPUs available on each node. - gpus_per_task: Number of GPUs to be used for each task. - tasks_per_node: Number of tasks to run on each node, if not set defaults - to Slurm's output environment variable SLURM_NTASKS_PER_NODE. - auto_set_gpu: Set the visible CUDA devices automatically while resolving - the cluster by setting CUDA_VISIBLE_DEVICES environment variable. - Defaults to True. - - Returns: - A ClusterResolver object which can be used with distributed TensorFlow. - - Raises: - RuntimeError: If requested more GPUs per node then available or requested - more tasks then assigned tasks. - """ - - # check if launched by mpirun - if 'OMPI_COMM_WORLD_RANK' in os.environ: - self._rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - num_tasks = int(os.environ['OMPI_COMM_WORLD_SIZE']) - else: - self._rank = int(os.environ['SLURM_PROCID']) - num_tasks = int(os.environ['SLURM_NTASKS']) - - self._jobs = collections.OrderedDict(sorted(jobs.items())) - self._port_base = port_base - - # user specification overrides SLURM specification - if tasks_per_node is not None: - self._tasks_per_node = tasks_per_node - elif tasks_per_node is None and 'SLURM_NTASKS_PER_NODE' in os.environ: - self._tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE']) - else: - raise RuntimeError('Neither `tasks_per_node` or ' - 'SLURM_NTASKS_PER_NODE is set.') - - self._gpus_per_node = gpus_per_node - self._gpus_per_task = gpus_per_task - - self._auto_set_gpu = auto_set_gpu - self._job_name = None - self._task_index = None - - self._gpu_allocation = [] - self._cluster_allocation = {} - - if self._tasks_per_node * self._gpus_per_task > self._gpus_per_node: - raise RuntimeError('Requested more GPUs per node then available.') - - if sum(self._jobs.values()) != num_tasks: - raise RuntimeError('Requested more tasks 
then assigned tasks.') - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest instance group info. - - This returns a ClusterSpec object for use based on information from the - specified initialization parameters and Slurm environment variables. The - cluster specification is resolved each time this function is called. The - resolver extract hostnames of nodes by scontrol and pack tasks in that - order until a node a has number of tasks that is equal to specification. - GPUs on nodes are allocated to tasks by specification through setting - CUDA_VISIBLE_DEVICES environment variable. - - Returns: - A ClusterSpec containing host information retrieved from Slurm's - environment variables. - """ - hostlist = self._resolve_hostnames() - - task_list = [] - self._gpu_allocation = [] - self._cluster_allocation = {} - - for host in hostlist: - for port_offset, gpu_offset in zip( - range(self._tasks_per_node), - range(0, self._gpus_per_node, self._gpus_per_task)): - - host_addr = '%s:%d' % (host, self._port_base + port_offset) - task_list.append(host_addr) - gpu_id_list = [] - - for gpu_id in range(gpu_offset, gpu_offset + self._gpus_per_task): - gpu_id_list.append(str(gpu_id)) - - self._gpu_allocation.append(','.join(gpu_id_list)) - - cluster_rank_offset_start = 0 - cluster_rank_offset_end = 0 - - for job_name, num_tasks in self._jobs.items(): - cluster_rank_offset_end = cluster_rank_offset_start + num_tasks - - self._cluster_allocation[job_name] = \ - task_list[cluster_rank_offset_start:cluster_rank_offset_end] - - if self._rank >= cluster_rank_offset_start and \ - self._rank < cluster_rank_offset_end: - - self._job_name = job_name - self._task_index = self._rank - cluster_rank_offset_start - - cluster_rank_offset_start = cluster_rank_offset_end - - if self._auto_set_gpu is True: - os.environ['CUDA_VISIBLE_DEVICES'] = self._gpu_allocation[self._rank] - - return ClusterSpec(self._cluster_allocation) - - def get_task_info(self): - """Returns job 
name and task_index for the process which calls this. - - This returns the job name and task index for the process which calls this - function according to its rank and cluster specification. The job name and - task index are set after a cluster is constructed by cluster_spec otherwise - defaults to None. - - Returns: - A string specifying job name the process belongs to and an integner - specifying the task index the process belongs to in that job. - """ - return self._job_name, self._task_index - - def master(self, task_type=None, task_index=None): - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) - return self._cluster_allocation[str(self._job_name)][self._task_index] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py index 7bbd189d03d9c96914d11948941916739f10d18f..9db7f47dcb49c499719b9002b1d2d6c4837a7bd2 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py @@ -12,81 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables.""" - +"""Stub file for TFConfigClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json -import os - -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec - -_TF_CONFIG_ENV = 'TF_CONFIG' -_SESSION_MASTER_KEY = 'session_master' - - -class TFConfigClusterResolver(ClusterResolver): - """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.""" - - def _load_tf_config(self): - return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) - - def cluster_spec(self): - """Returns a ClusterSpec based on the TF_CONFIG environment variable. - - Returns: - A ClusterSpec with information from the TF_CONFIG environment variable. - """ - tf_config = self._load_tf_config() - if 'cluster' not in tf_config: - return ClusterSpec({}) - return ClusterSpec(tf_config['cluster']) - - def master(self, task_type=None, task_index=0): - """Returns the master address to use when creating a TensorFlow session. - - Args: - task_type: (String, optional) Overrides and sets the task_type of the - master. - task_index: (Integer, optional) Overrides and sets the task id of the - master. - - Returns: - The address of the master. - - Raises: - RuntimeError: If the task_type or task_id is not specified and the - `TF_CONFIG` environment variable does not contain a task section. - """ +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. - # If `session_master` is set, just use that. 
- tf_config = self._load_tf_config() - if _SESSION_MASTER_KEY in tf_config: - return tf_config[_SESSION_MASTER_KEY] +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +# pylint: enable=unused-import - if 'rpc_layer' in tf_config: - rpclayer = '%s://' % tf_config['rpc_layer'] - else: - rpclayer = '' +from tensorflow.python.util.all_util import remove_undocumented - # Return an empty string if we are the only job in the ClusterSpec. - cluster_spec = self.cluster_spec() - if (not cluster_spec.jobs or - (len(cluster_spec.jobs) == 1 and - len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)): - return '' +_allowed_symbols = [ + 'TFConfigClusterResolver', +] - # We try to auto-detect the task type and id, but uses the user-supplied one - # where available - if not task_type: - if 'task' not in tf_config: - raise RuntimeError('You must either specify a `task_type`, or your ' - 'TF_CONFIG must contain a `task` section.') - task_type = tf_config['task']['type'] - task_index = tf_config['task']['index'] +remove_undocumented(__name__, _allowed_symbols) - return rpclayer + cluster_spec.task_address(task_type, task_index) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1f6803a9ff9a7a1e72ee691afd7e22bb4d85475c..3a1eaccd06e574babbe9a3232dacd1d66f3a4648 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,341 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for Cloud TPUs.""" +"""Stub file for TPUClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -from six.moves.urllib.request import Request -from six.moves.urllib.request import urlopen +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver +# pylint: enable=unused-import -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training import server_lib -from tensorflow.python.util import compat +from tensorflow.python.util.all_util import remove_undocumented -_GOOGLE_API_CLIENT_INSTALLED = True -try: - from googleapiclient import discovery # pylint: disable=g-import-not-at-top - from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top -except ImportError: - _GOOGLE_API_CLIENT_INSTALLED = False +_allowed_symbols = [ + 'TPUClusterResolver', +] - -_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' -_ENDPOINTS_SEPARATOR = ',' -_DEFAULT_ENV_VARIABLE = 'TPU_NAME' -_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' - - -class TPUClusterResolver(ClusterResolver): - """Cluster Resolver for Google Cloud TPUs. - - This is an implementation of cluster resolvers for the Google Cloud TPU - service. 
As Cloud TPUs are in alpha, you will need to specify a API definition - file for this to consume, in addition to a list of Cloud TPUs in your Google - Cloud Platform project. - """ - - def _tpuService(self): - """Creates a new Cloud TPU API object. - - This works around an issue where the underlying HTTP connection sometimes - times out when the script has been running for too long. Other methods in - this object calls this method to get a new API object whenever they need - to communicate with the Cloud API. - - Returns: - A Google Cloud TPU API object. - """ - if self._service: - return self._service - - credentials = self._credentials - if credentials is None or credentials == 'default': - credentials = GoogleCredentials.get_application_default() - - if self._discovery_url: - return discovery.build( - 'tpu', 'v1alpha1', - credentials=credentials, - discoveryServiceUrl=self._discovery_url) - else: - return discovery.build( - 'tpu', 'v1alpha1', - credentials=credentials) - - def _requestComputeMetadata(self, path): - req = Request('http://metadata/computeMetadata/v1/%s' % path, - headers={'Metadata-Flavor': 'Google'}) - resp = urlopen(req) - return compat.as_bytes(resp.read()) - - def _shouldResolve(self): - if (self._tpu == compat.as_bytes('') or - self._tpu == compat.as_bytes('local') or - self._tpu.startswith(compat.as_bytes('/bns')) or - self._tpu.startswith(compat.as_bytes('localhost:')) or - self._tpu.startswith(compat.as_bytes('grpc://'))): - return False - return True - - @staticmethod - def _inGke(): - """When running in GKE, the environment variable will be set.""" - return _GKE_ENV_VARIABLE in os.environ - - @staticmethod - def _gkeEndpoints(): - return os.environ[_GKE_ENV_VARIABLE] - - @staticmethod - def _envVarFallback(): - if _DEFAULT_ENV_VARIABLE in os.environ: - return os.environ[_DEFAULT_ENV_VARIABLE] - return None - - @staticmethod - def _environmentDiscoveryUrl(): - return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) - - def 
__init__(self, - tpu=None, - zone=None, - project=None, - job_name='worker', - coordinator_name=None, - coordinator_address=None, - credentials='default', - service=None, - discovery_url=None): - """Creates a new TPUClusterResolver object. - - The ClusterResolver will then use the parameters to query the Cloud TPU APIs - for the IP addresses and ports of each Cloud TPU listed. - - Args: - tpu: Either a string, or a list of strings corresponding to the TPUs to - use. If the single string is the empty string, the string 'local', or a - string that begins with 'grpc://' or '/bns', then it is assumed to not - correspond with a Cloud TPU and will instead be passed as the session - master and no ClusterSpec propagation will be done. - zone: Zone where the TPUs are located. If omitted or empty, we will assume - that the zone of the TPU is the same as the zone of the GCE VM, which we - will try to discover from the GCE metadata service. - project: Name of the GCP project containing Cloud TPUs. If omitted or - empty, we will try to discover the project name of the GCE VM from the - GCE metadata service. - job_name: Name of the TensorFlow job the TPUs belong to. - coordinator_name: The name to use for the coordinator. Set to None if the - coordinator should not be included in the computed ClusterSpec. - coordinator_address: The address of the coordinator (typically an ip:port - pair). If set to None, a TF server will be started. If coordinator_name - is None, a TF server will not be started even if coordinator_address is - None. - credentials: GCE Credentials. If None, then we use default credentials - from the oauth2client - service: The GCE API object returned by the googleapiclient.discovery - function. If you specify a custom service object, then the credentials - parameter will be ignored. - discovery_url: A URL template that points to the location of - the discovery service. 
It should have two parameters {api} and - {apiVersion} that when filled in produce an absolute URL to the - discovery document for that service. The environment variable - 'TPU_API_DISCOVERY_URL' will override this. - - Raises: - ImportError: If the googleapiclient is not installed. - ValueError: If no TPUs are specified. - """ - if isinstance(tpu, list): - if not tpu: - raise ValueError('At least one TPU must be specified.') - if len(tpu) != 1: - raise NotImplementedError( - 'Using multiple TPUs in a single session is not yet implemented') - tpu = tpu[0] - - in_gke = self._inGke() - # When using GKE with Cloud TPUs, the env variable will be set. - if tpu is None: - if in_gke: - tpu = self._gkeEndpoints() - else: - tpu = self._envVarFallback() - - if tpu is None: - raise ValueError('Please provide a TPU Name to connect to.') - - self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes - self._job_name = job_name - - # Whether we should actually attempt to contact Cloud APIs - should_resolve = self._shouldResolve() - - # We error out if we are in a non-Cloud environment which cannot talk to the - # Cloud APIs using the standard class and a special object is not passed in. - self._service = service - if (self._service is None and should_resolve and - not _GOOGLE_API_CLIENT_INSTALLED): - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - # We save user-passed credentials, unless the user didn't pass in anything. - self._credentials = credentials - if (credentials == 'default' and should_resolve and - _GOOGLE_API_CLIENT_INSTALLED): - self._credentials = None - - # Automatically detect project and zone if unspecified. 
- if not project and should_resolve: - project = compat.as_str( - self._requestComputeMetadata('project/project-id')) - if not zone and should_resolve: - zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) - zone = zone_path.split('/')[-1] - self._project = project - self._zone = zone - - self._discovery_url = self._environmentDiscoveryUrl() or discovery_url - - self._coordinator_name = coordinator_name - if (coordinator_name and not coordinator_address and - (should_resolve or in_gke)): - self._start_local_server() - else: - self._coordinator_address = coordinator_address - - def master(self, task_type=None, task_index=None): - """Get the Master string to be used for the session. - - In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of - first instance in the ClusterSpec returned by the cluster_spec function. - - If a non-TPU name is used when constructing a TPUClusterResolver, that will - be returned instead (e.g. If the tpus argument's value when constructing - this TPUClusterResolver was 'grpc://10.240.1.2:8470', - 'grpc://10.240.1.2:8470' will be returned). - - Args: - task_type: (Optional) The type of the TensorFlow task of the master. - task_index: (Optional) The index of the TensorFlow task of the master. - - Returns: - string, the connection string to use when creating a session. - - Raises: - ValueError: If none of the TPUs specified exists. 
- """ - if not self._shouldResolve(): - return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] - - cluster_spec = self.cluster_spec() - if task_type and task_index: - return cluster_spec.task_address(task_type, task_index) - - job_tasks = cluster_spec.job_tasks(self._job_name) - if not job_tasks: - raise ValueError('No TPUs exists with the specified names exist.') - - return 'grpc://' + job_tasks[0] - - def get_master(self): - return self.master() - - def get_job_name(self): - if self._shouldResolve(): - return self._job_name - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest TPU information. - - We retrieve the information from the GCE APIs every time this method is - called. - - Returns: - A ClusterSpec containing host information returned from Cloud TPUs. - - Raises: - RuntimeError: If the provided TPU is not healthy. - """ - ############################################################################ - # There are 5 potential cases this code must handle: - # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and - # a. Create a ClusterSpec that includes the coordinator job - # b. Create a ClusterSpec without the coordinator job. - # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of - # tasks and - # a. Create a ClusterSpec with the coordinator - # b. Create a ClusterSpec without the coordinator - # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec. - ############################################################################ - - if self._shouldResolve(): - # Case 1. 
- full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, compat.as_text(self._tpu)) - service = self._tpuService() - request = service.projects().locations().nodes().get(name=full_name) - response = request.execute() - - if 'state' in response and response['state'] != 'READY': - raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (compat.as_text(self._tpu), response['state'])) - - if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % - (compat.as_text(self._tpu), response['health'])) - - if 'networkEndpoints' in response: - worker_list = [ - '%s:%s' % (endpoint['ipAddress'], endpoint['port']) - for endpoint in response['networkEndpoints'] - ] - else: - # Fall back to the deprecated response format - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list = [instance_url] - - cluster_spec = {self._job_name: worker_list} - else: - if not self._tpu.startswith(compat.as_bytes('grpc://')): - # Case 3. - return None - # Case 2. 
- cluster_spec = { - self._job_name: [ - x[len(compat.as_bytes('grpc://')):] - for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) - ] - } - - if self._coordinator_address: - # {1, 2}.a - cluster_spec[self._coordinator_name] = [self._coordinator_address] - - return server_lib.ClusterSpec(cluster_spec) - - def _start_local_server(self): - address = self._requestComputeMetadata('instance/network-interfaces/0/ip') - self._server = server_lib.Server( - { - 'local': ['0.0.0.0:0'] - }, protocol='grpc', config=None, start=True) - # self._server.target is of the form: grpc://ipaddress:port - target = compat.as_bytes(self._server.target) - splits = target.split(compat.as_bytes(':')) - assert len(splits) == 3, self._server.target - assert splits[0] == compat.as_bytes('grpc'), self._server.target - self._coordinator_port = compat.as_text(splits[2]) - self._coordinator_address = '%s:%s' % ( - address, compat.as_text(self._coordinator_port)) - - def __deepcopy__(self, memo): - # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. - return self +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index fca24b16043524c7651c7b7a3a83cac1bfdd53fb..df8b48dfc46124d3b9454d92ffb70dbcf1bc4217 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -5,10 +5,10 @@ CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all platforms. For details, see the [TensorFlow install guide](https://www.tensorflow.org/install/). -This directory contains CMake files for building TensorFlow on Microsoft -Windows and Linux. [CMake](https://cmake.org) is a cross-platform tool that can -generate build scripts for multiple build systems, including Microsoft -Visual Studio and GCC. "The method has not been tested on Mac OS X. +This directory contains CMake files for building TensorFlow on Microsoft Windows +and Linux. 
[CMake](https://cmake.org) is a cross-platform tool that can generate +build scripts for multiple build systems, including Microsoft Visual Studio and +GCC. The method has not been tested on Mac OS X. **N.B.** We provide Linux build instructions primarily for the purpose of testing the build. We recommend using the standard Bazel-based build on @@ -17,13 +17,17 @@ Linux. Current Status -------------- -CMake can be used to build TensorFlow on all platforms. See the [getting started documentation](https://www.tensorflow.org/install/install_windows) -for instructions on how to install a pre-built TensorFlow package on Windows and Linux. The procedure in MacOS is similar to the Linux build. +CMake can be used to build TensorFlow on all platforms. See the +[getting started documentation](https://www.tensorflow.org/install/install_windows) +for instructions on how to install a pre-built TensorFlow package on Windows and +Linux. The procedure in MacOS is similar to the Linux build. ### Current known limitations -* It is not possible to load a custom Op library. -* GCS file system is not supported. -* Debug build is not available since Python for Windows is no longer distributed with a debug library. + +* It is not possible to load a custom Op library. +* GCS file system is not supported. +* Debug build is not available since Python for Windows is no longer + distributed with a debug library. ## Building with CMake @@ -33,77 +37,88 @@ bindings. ### Prerequisites -* CMake version 3.5 or later. +* CMake version 3.5 or later.
+ +* [Git](https://git-scm.com) + +* [SWIG](http://www.swig.org/download.html) + +* [Perl](https://www.perl.org/get.html) (optional, for SSL support build) + +* [Go](https://golang.org/) (optional, for SSL support build) -* [Git](https://git-scm.com) +* [NASM](http://www.nasm.us/)/[YASM](http://yasm.tortall.net/) (optional, for + SSL support build) -* [SWIG](http://www.swig.org/download.html) +* Additional pre-requisites for Microsoft Windows: -* [Perl](https://www.perl.org/get.html) (optional, for SSL support build) + - Visual Studio 2015 (latest version of MSVC 2017 is not supported by CUDA + yet, try it on your own risk) -* [Go](https://golang.org/) (optional, for SSL support build) + - Python 3.5 -* [NASM](http://www.nasm.us/)/[YASM](http://yasm.tortall.net/) (optional, for SSL support build) +* Additional prerequisites for Linux: -* Additional pre-requisites for Microsoft Windows: - - Visual Studio 2015 (latest version of MSVC 2017 is not supported by CUDA yet, try it on your own risk) - - - Python 3.5 + - Python 2.7 or later + - [Docker](https://www.docker.com/) (for automated testing) -* Additional prerequisites for Linux: - - Python 2.7 or later - - [Docker](https://www.docker.com/) (for automated testing) +* Python dependencies: -* Python dependencies: - - wheel - - NumPy 1.11.0 or later + - wheel + - NumPy 1.11.0 or later ### Known-good configurations -* Microsoft Windows 10 - - Microsoft Visual Studio Enterprise/ Community 2015 with Visual C++ 2015 - - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - - [swigwin-3.0.10](http://www.swig.org/download.html) - - [NVidia CUDA Toolkit 9.0](https://developer.nvidia.com/cuda-downloads) - - [NVidia CUDNN 7](https://developer.nvidia.com/cudnn) - - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) +* Microsoft Windows 10 -* Ubuntu 14.04 - - Makefile generator - - Docker 1.9.1 (for automated 
testing) + - Microsoft Visual Studio Enterprise/ Community 2015 with Visual C++ 2015 + - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) + - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) + - [swigwin-3.0.10](http://www.swig.org/download.html) + - [NVidia CUDA Toolkit 9.0](https://developer.nvidia.com/cuda-downloads) + - [NVidia CUDNN 7](https://developer.nvidia.com/cudnn) + - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) + +* Ubuntu 14.04 + + - Makefile generator + - Docker 1.9.1 (for automated testing) ### Current known limitations - - The Python package supports **Python 3.5/3.6 only**, because these are the only - versions for which standard Python binaries exist and those binaries are - compatible with the TensorFlow runtime. (On Windows, the standard Python + +- The Python package supports **Python 3.5/3.6 only**, because these are the + only versions for which standard Python binaries exist and those binaries + are compatible with the TensorFlow runtime. (On Windows, the standard Python binaries for versions earlier than 3.5 were compiled with older compilers that do not have all of the features (e.g. C++11 support) needed to compile - TensorFlow. We welcome patches for making TensorFlow work with Python 2.7 - on Windows, but have not yet committed to supporting that configuration.) - - - The following Python APIs are not currently implemented: - * Loading custom op libraries via `tf.load_op_library()`. In order to use your - custom op, please put the source code under the tensorflow/core/user_ops - directory, and a shape function is required (not optional) for each op. - * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not - functional. - - - The `tf.contrib` libraries are not currently included in the PIP package. 
- - - The following operations are not currently implemented: - * `DepthwiseConv2dNative` - * `Digamma` - * `Erf` - * `Erfc` - * `Igamma` - * `Igammac` - * `ImmutableConst` - * `Lgamma` - * `Polygamma` - * `Zeta` - - - Google Cloud Storage support is not currently implemented. The GCS library + TensorFlow. We welcome patches for making TensorFlow work with Python 2.7 on + Windows, but have not yet committed to supporting that configuration.) + +- The following Python APIs are not currently implemented: + + * Loading custom op libraries via `tf.load_op_library()`. In order to use + your custom op, please put the source code under the + tensorflow/core/user_ops directory, and a shape function is required + (not optional) for each op. + * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not + functional. + +- The `tf.contrib` libraries are not currently included in the PIP package. + +- The following operations are not currently implemented: + + * `DepthwiseConv2dNative` + * `Digamma` + * `Erf` + * `Erfc` + * `Igamma` + * `Igammac` + * `ImmutableConst` + * `Lgamma` + * `Polygamma` + * `Zeta` + +- Google Cloud Storage support is not currently implemented. The GCS library currently depends on `libcurl` and `boringssl`, and the Windows version could use standard Windows APIs for making HTTP requests and cryptography (for OAuth). Contributions are welcome for this feature. @@ -112,97 +127,145 @@ We are actively working on improving CMake and Windows support, and addressing these limitations. We would appreciate pull requests that implement missing ops or APIs. -CMake GUI build (all platforms) -================================== -Install from CMake GUI would be a convenient way to generate C++ build projects. The software supports Windows, MacOS and Linux, while the posix platform provides an extra ccmake binary to run command line GUI. 
Both working principal of cmake, ccmake and cmake-gui are the same, the only difference is by providing suitable interface for project configuration and dependency setting. - -0. Pre-buid checklist: - The following binary/libraries should be setted in system path, otherwise you need to set manualy via cmake. - * Compiler (GCC for Linux, MSVC for Windows) - * Make sure compiler directory has been set to system path - * CUDA 9.0 (GPU build) - * CUDNN (GPU build) - * NCCL (GPU build on Linux) - * SWIG (python binding) - * Perl (required if you need ssl support, optional) - * Go (required if you need ssl support, optional) - * NASM/YASM (required by grpc for ssl support, optional) -1. Start CMake GUI -2. Click on `Browse Source` and direct to the the folder `/tensorflow/contrib/cmake` -3. Click on `Browse Build` and spectify a location that you want tensorflow to be build -4. Click on `Configure`, a new window will be prompted out, specify the generator mode for the project generation. For Windows, choose `Visual Studio Win64`, for Linux, choose `Unix Makefiles`, then press `Finish`. Wait for a moment, the default project dependecy would automatically generate. -5. There are a few options that you can customize your own build. **The setting here is crucial for a sucessful build, please check all items carefully.** - * `tensorflow_BUILD_ALL_KERNELS` should alway be `on` - * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. This can help you to test build (optional) - * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't affect tensorflow function, turn it to `off` if you want a slim build. (optional) - * `tensorflow_BUILD_PYTHON_BINDING` is default to be `on`. Set to `off` if you don't need python interaface. If SWIG is not in system path, you need set it manually. (optional) - * `tensorflow_BUILD_SHARED_LIB` is default to be `off`. Set to `on` if you want the c++ interface. (optional) - * `tensorflow_ENABLE_GPU` is default to be `off`. 
Set to `on` if you want GPU support. It will search CUDA and CUDNN dependecies if you have set them to system path, otherwise CMake would prompt error and request you to set it manually. (optional) - * `tensorflow_ENABLE_GRPC_SUPPORT` is default to be `on`. For Linux build, this option must always be `on`. This need to be `on` for a gpu build. Reminded that Perl, Go and NASM/YASM are required for this option if you want to build grpc with offical SSL support. - * `tensorflow_ENABLE_POSITION_INDEPENDENT_CODE` should always be `on` - * `tensorflow_ENABLE_SNAPPY_SUPPORT` should always be `on` - * `tensorflow_OPTIMIZE_FOR_NATIVE_ARCH` should always be `on` - * `CMAKE_INSTALL_PREFIX` is the location where the final package will be installed. You may change it to your own preferred path (optional) - -6. After changing the configuration in step 5, press `Configure` again -7. If not error is found, press `Generate` +# CMake GUI build (all platforms) + +Install from CMake GUI would be a convenient way to generate C++ build projects. +The software supports Windows, MacOS and Linux, while the posix platform +provides an extra ccmake binary to run command line GUI. Both working principal +of cmake, ccmake and cmake-gui are the same, the only difference is by providing +suitable interface for project configuration and dependency setting. + +1. Pre-buid checklist: The following binary/libraries should be setted in + system path, otherwise you need to set manualy via cmake. + * Compiler (GCC for Linux, MSVC for Windows) + * Make sure compiler directory has been set to system path + * CUDA 9.0 (GPU build) + * CUDNN (GPU build) + * NCCL (GPU build on Linux) + * SWIG (python binding) + * Perl (required if you need ssl support, optional) + * Go (required if you need ssl support, optional) + * NASM/YASM (required by grpc for ssl support, optional) +2. Start CMake GUI +3. Click on `Browse Source` and direct to the the folder + `/tensorflow/contrib/cmake` +4. 
Click on `Browse Build` and specify a location that you want tensorflow to + be built +5. Click on `Configure`, a new window will be prompted out, specify the + generator mode for the project generation. For Windows, choose `Visual + Studio Win64`, for Linux, choose `Unix Makefiles`, then + press `Finish`. Wait for a moment, the default project dependency would + automatically generate. +6. There are a few options that you can customize your own build. **The setting + here is crucial for a successful build, please check all items carefully.** + + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` + * `tensorflow_BUILD_CC_EXAMPLE` is default to be `on`. This can help you + to test build (optional) + * `tensorflow_BUILD_CONTRIB_KERNELS` is default to be `on`, but it won't + affect tensorflow function, turn it to `off` if you want a slim build. + (optional) + * `tensorflow_BUILD_PYTHON_BINDING` is default to be `on`. Set to `off` if + you don't need python interface. If SWIG is not in system path, you + need to set it manually. (optional) + * `tensorflow_BUILD_SHARED_LIB` is default to be `off`. Set to `on` if you + want the c++ interface. (optional) + * `tensorflow_ENABLE_GPU` is default to be `off`. Set to `on` if you want + GPU support. It will search CUDA and CUDNN dependencies if you have set + them to system path, otherwise CMake would prompt error and request you + to set it manually. (optional) + * `tensorflow_ENABLE_GRPC_SUPPORT` is default to be `on`. For Linux build, + this option must always be `on`. This needs to be `on` for a gpu build. + Note that Perl, Go and NASM/YASM are required for this option if you + want to build grpc with official SSL support. + * `tensorflow_ENABLE_POSITION_INDEPENDENT_CODE` should always be `on` + * `tensorflow_ENABLE_SNAPPY_SUPPORT` should always be `on` + * `tensorflow_OPTIMIZE_FOR_NATIVE_ARCH` should always be `on` + * `CMAKE_INSTALL_PREFIX` is the location where the final package will be + installed.
You may change it to your own preferred path (optional) + +7. After changing the configuration in step 6, press `Configure` again + +8. If no error is found, press `Generate` #### Windows -1. Open `tensorflow.sln` in the build folder (Windows). Change build type from `Debug` to `Release`. Choose `Build`->`Build Solution`. This may take more than hours of compilation. If everything is alright, the output window would show no error. +1. Open `tensorflow.sln` in the build folder (Windows). Change build type from + `Debug` to `Release`. Choose `Build`->`Build Solution`. This may take more + than hours of compilation. If everything is alright, the output window would + show no error. ##### Python - In solution explorer, right click on `tf_python_build_pip_package` -> `build`. It will generate the wheel file in `/tf_python/dist`. Install with following command: + In solution explorer, right click on `tf_python_build_pip_package` -> + `build`. It will generate the wheel file in + `/tf_python/dist`. Install with the following command: - ```pip install --upgrade tensorflow-.whl``` + `pip install --upgrade tensorflow-.whl` - ***The wheel name varies depends on you config. Change to your own wheel filename.*** + ***The wheel name varies depending on your config. Change to your own wheel + filename.*** - Reminded that some pip installation requires administrator right command prompt. + Note that some pip installations require an administrator command + prompt. ##### C++ - You can directly use the build folder tree for C++ interface with cmake. If you want to do installation for api releasing, right click on `Install` -> `build`. The headers and library will be installed in the directory specify by `CMAKE_INSTALL_PREFIX` during configuration. + You can directly use the build folder tree for C++ interface with cmake. If + you want to do installation for api releasing, right click on `Install` -> + `build`.
The headers and library will be installed in the directory specify + by `CMAKE_INSTALL_PREFIX` during configuration. -2. For smaller RAM computer, it is noticed that out of heap space error appears. Change to command prompt build is an alternative to do step 1. +1. For smaller RAM computer, it is noticed that out of heap space error + appears. Change to command prompt build is an alternative to do step 1. - Open `VS2015 x64 Native Tools Command Prompt`. You can open it by press `Start`, then type the binary name. Use `VS2017 x64 Native Tools Command Prompt` if you are using MSVC 2017. + Open `VS2015 x64 Native Tools Command Prompt`. You can open it by press + `Start`, then type the binary name. Use `VS2017 x64 Native Tools Command + Prompt` if you are using MSVC 2017. ##### Python Directly build python wheel package by following command: - ```MSBuild /p:Configuration=Release ``` + `MSBuild /p:Configuration=Release + ` - Remember to change `` to the actual path of the file, it can be found at the root of build directory + Remember to change `` to the + actual path of the file, it can be found at the root of build directory Install the wheel file generated as instructed by step 1. ##### C++ interface - Build from VS native toolchain with following command: - ```MSBuild /p:Configuration=Release ``` - Headers are discretely located in the build folders. Tensorflow library can be found at `/Release`, namely `tensorflow.dll` and `tensorflow.lib`. + Build from VS native toolchain with following command: `MSBuild + /p:Configuration=Release ` + + Headers are discretely located in the build folders. Tensorflow library can + be found at `/Release`, namely `tensorflow.dll` and + `tensorflow.lib`. 
- * Build to install for api release (optional): - ```MSBuild /p:Configuration=Release ``` + * Build to install for api release (optional): `MSBuild + /p:Configuration=Release ` - Remember to change `` and `` to the actual path of the file, it can be found at the root of build directory. + Remember to change `` and + `` to the actual path of the file, it can be found + at the root of build directory. #### Linux/MacOS (command line GNU build) -1. Open the terminal, change working directory to the one specified in step 3. +1. Open the terminal, change working directory to the one specified in step 3. -2. Type the following command: +2. Type the following command: - ```make -sj all``` + `make -sj all` ##### Python - **Important Note** CMake generated python wheel for Linux/MacOs is currently under development. Please use bazel build. + **Important Note** CMake generated python wheel for Linux/MacOs is currently + under development. Please use bazel build. - Follow code is an expected Linux/MacOS python package build after development work is completed. + Follow code is an expected Linux/MacOS python package build after + development work is completed. ``` make -sj tf_python_build_pip_package @@ -212,52 +275,63 @@ Install from CMake GUI would be a convenient way to generate C++ build projects. ##### C++ interface - ```make -sj install``` + `make -sj install` - Where `` is the threads used for the compilation, change to any integer less or equal to your computer's maxiumum thread number. + Where `` is the threads used for the compilation, change + to any integer less or equal to your computer's maxiumum thread number. - Headers are discretely located in the build folders. Tensorflow library can be found at ``, namely `tensorflow.so` (Linux) or `tensorflow.dylib` (MacOS). + Headers are discretely located in the build folders. Tensorflow library can + be found at ``, namely `tensorflow.so` (Linux) or + `tensorflow.dylib` (MacOS). 
#### Start a Tensorflow C++ project with CMake -Here we assume that you have basic knowledge on gathering dependency with `CMakeLists.txt`. Here we introduce how the C++ api works with [official hello world tutorial](https://www.tensorflow.org/api_guides/cc/guide). -1. Create a new working directory and create a new text file named `CMakeLists.txt` and the c++ file `main.cxx` -2. Fill in the `main.cxx` with the code provided in [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). -3. Fill in the `CMakeLists.txt` with following code: - ``` cmake - cmake_minimum_required (VERSION 2.6) - project (tf_hello) +Here we assume that you have basic knowledge on gathering dependency with +`CMakeLists.txt`. Here we introduce how the C++ api works with +[official hello world tutorial](https://www.tensorflow.org/api_guides/cc/guide). + +1. Create a new working directory and create a new text file named + `CMakeLists.txt` and the c++ file `main.cxx` +2. Fill in the `main.cxx` with the code provided in + [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). +3. Fill in the `CMakeLists.txt` with following code: ``` cmake + cmake_minimum_required (VERSION 2.6) project (tf_hello) # Tensorflow + find_package(Tensorflow REQUIRED) include_directories(${TENSORFLOW_INCLUDE_DIRS}) # compiler setting required by tensorflow, to be tested on all compilers + # currently only tested on MSVC and GCC - if (${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) - add_definitions(-DCOMPILER_MSVC) - elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) - if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "3") - add_definitions(-DCOMPILER_GCC3) - else() - add_definitions(-D__GNUC__) - endif() - else() - message(ERROR " compiler ${CMAKE_CXX_COMPILER_ID} not supported by this CMakeList.txt, under development") - endif() - - add_executable(tf_hello main.cxx) - target_link_libraries(tf_hello ${TENSORFLOW_LIBRARIES}) - ``` -4. 
Configure the folder with cmake-gui, an error should be prompted out, requesting you to locate the folder containing `TensorflowConfig.cmake`. This file can be found at `` or `` (for those have build install in previous steps). -5. Configure again, generate the project. -6. Compile the project with `Release` config (Windows). For Linux users, just compile the project. -7. Copy the `tensorflow.dll`(Windows)/`tensorflow.so`(Linux) from build directory to the build folder containing `tf_hello` binary. -8. Run `tf_hello` binary + if (${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) add_definitions(-DCOMPILER_MSVC) + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) if + (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "3") + add_definitions(-DCOMPILER_GCC3) else() add_definitions(-D__GNUC__) endif() + else() message(ERROR " compiler ${CMAKE_CXX_COMPILER_ID} not supported by + this CMakeList.txt, under development") endif() + + add_executable(tf_hello main.cxx) target_link_libraries(tf_hello + ${TENSORFLOW_LIBRARIES}) ``` + +4. Configure the folder with cmake-gui, an error should be prompted out, + requesting you to locate the folder containing `TensorflowConfig.cmake`. + This file can be found at `` or `` (for + those have build install in previous steps). + +5. Configure again, generate the project. + +6. Compile the project with `Release` config (Windows). For Linux users, just + compile the project. + +7. Copy the `tensorflow.dll`(Windows)/`tensorflow.so`(Linux) from build + directory to the build folder containing `tf_hello` binary. + +8. Run `tf_hello` binary -Step-by-step Windows build (command prompt) -========================== +# Step-by-step Windows build (command prompt) 1. Install the prerequisites detailed above, and set up your environment. @@ -443,4 +517,4 @@ $ cd tensorflow $ tensorflow/tools/ci_build/ci_build.sh CMAKE tensorflow/tools/ci_build/builds/cmake.sh ``` -That's it. Dependencies included. \ No newline at end of file +That's it. Dependencies included. 
diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index eefa7d3f039295ab595b4233fab51e7733dd6236..b85fd48f0f34df93d9eaa31251ebe05c78b34a9e 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -31,8 +31,8 @@ if (systemlib_ABSEIL_CPP) message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") - add_custom_target(abseil_cpp_build) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + add_custom_target(abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) else (systemlib_ABSEIL_CPP) @@ -79,14 +79,11 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) endif() - ExternalProject_Add(abseil_cpp_build + ExternalProject_Add(abseil_cpp PREFIX abseil_cpp GIT_REPOSITORY ${abseil_cpp_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release - COMMAND ${CMAKE_COMMAND} --build . 
--config Release INSTALL_COMMAND "" CMAKE_CACHE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} @@ -99,6 +96,6 @@ else (systemlib_ABSEIL_CPP) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp_build) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index eca4f3c8c8866ff60c4ee8332a2baaa972fe3b83..e570c09ecb5e64130ed6f3375a51d74850cc3989 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) +set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 48dbfb92e65b0ed456846f83ddd5eed4d74dfe67..62005dd113bfb80fbdf23afb6d4aa5f90a1e32de 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -213,6 +213,10 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# absl directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/abseil_cpp/src/abseil_cpp/absl/ + DESTINATION include/absl + FILES_MATCHING PATTERN "*.h") # mkl if (tensorflow_ENABLE_MKL_SUPPORT) install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ diff --git 
a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 1630f010ab60db258b976c7bddc22ff78dccf890..e4566437c60ebb2da039e61c171fbe954a7355c9 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -58,6 +58,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/compiler/jit:xla_ops_py", + "//tensorflow/compiler/jit/ops:xla_ops_grad", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 335ac7946485f234d1af3d180283fc8daac50005..f867cd15b67dbd43650d8012b4299845af7200a8 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -23,6 +23,7 @@ import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.jit.ops import xla_ops +from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py index 41258edd90866ae9f644a02c42dfe2dc589da998..6926c0d03fe38ab2d62cc588950c7f5a49b2aba1 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -74,8 +74,8 @@ class ConstrainedMinimizationProblem(object): if (constraints_shape.ndims is None or proxy_constraints_shape.ndims is None or - any([ii is None for ii in constraints_shape.as_list()]) or - any([ii is None for ii in proxy_constraints_shape.as_list()])): + any(ii is None for ii in constraints_shape.as_list()) or + any(ii is None for ii in 
proxy_constraints_shape.as_list())): raise ValueError( "constraints and proxy_constraints must have fully-known shapes") if constraints_shape != proxy_constraints_shape: diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 656633f0bf21a4d46cb85547241ef0fd42807ed6..40e159b8fcbd1864284e208cb15d9ed96119f840 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -38,12 +38,12 @@ tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( [unary_scores, sequence_lengths, transition_params, train_op]) for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, tf_sequence_lengths): -# Remove padding. -tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] + # Remove padding. + tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] -# Compute the highest score and its tag sequence. -tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( - tf_unary_scores_, tf_transition_params) + # Compute the highest score and its tag sequence. 
+ tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( + tf_unary_scores_, tf_transition_params) """ from __future__ import absolute_import diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 670b54943277806c47bfd6c6bc9b345db4bb1448..8d35622e393e15a2f2dfea7c75ad2c9f48aa7150 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -42,10 +42,11 @@ tf_custom_op_py_library( cuda_py_test( name = "cudnn_rnn_ops_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ ":cudnn_rnn_py", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", @@ -61,7 +62,7 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - shard_count = 6, + shard_count = 2, tags = [ "noasan", # http://b/62067814 "requires-gpu-sm35", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index ae839108ebec31b70b687e5ff3e99c7d5a9b560e..a268415f0e65206294431a537be18cadbe1a1e84 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -18,24 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import itertools import os import unittest +from absl.testing import parameterized import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed 
from tensorflow.python.framework.test_util import TensorFlowTestCase from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -56,714 +62,989 @@ CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER -def _CreateModel(rnn_mode, - num_layers, - num_units, - input_size, - input_mode="linear_input", - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0.): - del input_mode - if rnn_mode == cudnn_rnn_ops.CUDNN_LSTM: - model_fn = cudnn_rnn_ops.CudnnLSTM - elif rnn_mode == cudnn_rnn_ops.CUDNN_GRU: - model_fn = cudnn_rnn_ops.CudnnGRU - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_TANH: - model_fn = cudnn_rnn_ops.CudnnRNNTanh - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_RELU: - model_fn = cudnn_rnn_ops.CudnnRNNRelu +def RunLSTM(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. + # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. when test training.") + + # set graph level random seed and numpy random seed. 
+ random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_c_op = variable_scope.get_variable( + "initial_c_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + + with variable_scope.variable_scope("test", initializer=initializer): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + num_units, num_units * 4], + dtype=dtype) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) + + # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. + cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) + outputs_op, state_tuple_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=rnn_cell_impl.LSTMStateTuple( + h=initial_h_op, c=initial_c_op), + dtype=dtype, + time_major=True, + scope=None) + + # Convert to cudnn opaque param. + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque([w, b]) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + cu_initial_c_op, + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) + # Remove the trivial 1st dimension. 
+ cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=0), + h=array_ops.squeeze(cu_h_op, axis=0)) + + if is_training: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + + (cu_inp_grad_op, cu_hgrad_op, + cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, + [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + # Remove the trivial 1st dimension + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + cu_wgrad_op = cu_wgrad_op[0] + cu_bgrad_op = cu_bgrad_op[0] + # cudnn lstm has 2 biases each gate. When converting to tf canonical format, + # the two biases are summed into one. Thus here bias gradient should be + # halved when comparing with tf lstm. + cu_bgrad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad) = sess.run([ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "state_grad: %s" % str(state_grad)) + logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad)) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + 
logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) else: - raise ValueError("Invalid rnn_mode: %s" % rnn_mode) - return model_fn( - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - - -def _CreateParamsSavable(params, - model, - base_variable_scope=None, - name="params_canonical"): - """Create a RNNParamsSaveable for the weight and bias parameters. + outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + return outputs, cu_outputs, state_tuple, cu_state_tuple + + +# Basic set of RNN configs to test. They can be further extended in relevant +# test (e.g. adding num_dirs). +NAMED_RNN_TESTCASES = ({ + "testcase_name": "xsmall", + "num_units": 1, + "input_size": 1, + "batch_size": 1, + "time": 1, + "num_layers": 1, +}, { + "testcase_name": "small", + "num_units": 4, + "input_size": 4, + "batch_size": 4, + "time": 4, + "num_layers": 1, +}, { + "testcase_name": "medium", + "num_units": 128, + "input_size": 64, + "batch_size": 8, + "time": 16, + "num_layers": 1, +}, { + "testcase_name": "large", + "num_units": 128, + "input_size": 128, + "batch_size": 16, + "time": 32, + "num_layers": 1, +}) + + +def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): + """Expands testcase with new config dimensions. 
+ + Example: + inputs = ( + {'testcase_name': 'test1', 'gender': 'male'} + {'testcase_name': 'test2', 'gender': 'female'} + ) + remove_keys: empty + extra_configs = { + 'age': [40, 80] + 'height': [5, 6] + } + + Returns: + ( + {'testcase_name': 'test1_age_40_height_5','gender': 'male', 'age': + 40,'height': 5} + {'testcase_name': 'test1_age_40_height_6', 'gender': 'male', 'age': 40, + 'height': 6} + {'testcase_name': 'test1_age_80_height_5', 'gender': 'male', 'age': 80, + 'height': 5} + {'testcase_name': 'test1_age_80_height_6', 'gender': 'male', 'age': 80, + 'height': 6} + + {'testcase_name': 'test2_age_40_height_5', 'gender': 'female', 'age': + 40, + 'height': 5} + {'testcase_name': 'test2_age_40_height_6', 'gender': 'female', 'age': + 40, + 'height': 6} + {'testcase_name': 'test2_age_80_height_5', 'gender': 'female', 'age': + 80, + 'height': 5} + {'testcase_name': 'test2_age_80_height_6', 'gender': 'female', 'age': + 80, + 'height': 6} + ) Args: - params: a Variable for weight and bias parameters. - model: a CudnnRNN model. - base_variable_scope: a string, prefix of names of saved variables. - name: a string, name of the RNNParamsSaveable object. + inputs: A list of dictionary, each being a testcase. + *remove_keys: A list of keys into testcase which are not needed in new + testcases. + **extra_configs: A dict of new test dimension and applicable values in that + dimension. + Returns: - a RNNParamsSaveable object. + A list of dictionary with expanded test cases. 
""" - if model._rnn_mode == CUDNN_LSTM: - fn = cudnn_rnn_ops.CudnnLSTMSaveable - elif model._rnn_mode == CUDNN_GRU: - fn = cudnn_rnn_ops.CudnnGRUSaveable - elif model._rnn_mode == CUDNN_RNN_TANH: - fn = cudnn_rnn_ops.CudnnRNNTanhSaveable - elif model._rnn_mode == CUDNN_RNN_RELU: - fn = cudnn_rnn_ops.CudnnRNNReluSaveable - params_saveable = fn( - params, - model.num_layers, - model.num_units, - model.input_size, - model.input_mode, - model.direction, - scope=base_variable_scope, - name=name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) - return params_saveable - - -def _MinLSTMParamSize(num_layers, - num_units, - input_size, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION): - if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units - all_biases = 8 * num_layers * num_units - return first_layer_weights + higher_layer_weights + all_biases - elif direction == cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = (num_layers - 1) * ( - 4 * 2 * num_units * num_units + 4 * num_units**2) - all_biases = 8 * num_layers * num_units - return 2 * (first_layer_weights + higher_layer_weights + all_biases) - else: - raise ValueError("%s direction is not supported.") + res = [] + ordered_extra_configs = collections.OrderedDict(extra_configs) + keys = ordered_extra_configs.keys() + # A list of list of configs. + # The outer loop is iterating keys, the innner is values of one key. + combined_kv = [[(k, v) for v in ordered_extra_configs[k]] for k in keys] + logging.info("combined_kv: %s", combined_kv) + for inp in inputs: + # Each inp is a dict + for config in itertools.product(*combined_kv): + new_inp = dict(inp) + # config is a list in the form of [(k_i, v_j), (k_p, v_q), ...] 
+ suffix = ["%s_%s" % (p[0], str(p[1])) for p in config] + suffix = "_".join(suffix) + new_inp["testcase_name"] += "_" + suffix + for k, v in config: + new_inp[k] = v + # Remove not used keys from the new test case. + if remove_keys: + if not isinstance(remove_keys, (list, tuple)): + remove_keys = [remove_keys] + for k in remove_keys: + new_inp.pop(k, None) + logging.info("new_inp: %s", new_inp) + res.append(new_inp) + # Dedup, necessary if `remove_keys` is set. + return [dict(t) for t in {tuple(d.items()) for d in res}] -class CudnnRNNTestSaveRestore(TensorFlowTestCase): - def _CompareWeights(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - for lw, rw in zip(lhs, rhs): - self.assertAllEqual(lw, rw) +class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): - def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): - self.assertEqual(len(lhs), len(rhs)) - if rnn_mode == CUDNN_LSTM: - num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_GRU: - num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_RNN_TANH: - num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER - else: - num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER - num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 - num_params_per_layer *= num_dirs - self.assertEqual(num_params_per_layer * num_layers, len(lhs)) - - for i in range(num_layers): - layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - if direction == CUDNN_RNN_UNIDIRECTION: - self._CompareSingleLayerBiases(layer_lhs, layer_rhs) - else: - size = len(layer_lhs) - fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] - fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] - self._CompareSingleLayerBiases(fw_lhs, fw_rhs) - self._CompareSingleLayerBiases(bw_lhs, bw_rhs) - - def _CompareSingleLayerBiases(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - - 
lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] - lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] - self.assertEqual(len(lf_lhs), len(rt_lhs)) - self.assertEqual(len(lf_rhs), len(rt_rhs)) - - sum_lhs, sum_rhs = [], [] - for lf, rt in zip(lf_lhs, rt_lhs): - sum_lhs.append(lf + rt) - for lf, rt in zip(lf_rhs, rt_rhs): - sum_rhs.append(lf + rt) - self.assertEqual(len(sum_lhs), len(sum_rhs)) - for lf, rt in zip(sum_lhs, sum_rhs): - self.assertAllEqual(lf, rt) + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, + state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( + sess, num_units, input_size, batch_size, time, num_layers) - def _testSaveRestoreVariable(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - params = variables.VariableV1( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) - saveable = _CreateParamsSavable(params, model) - weights, biases = saveable.format_converter._opaque_to_cu_canonical( - saveable._variables) - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. 
- with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + for s, cu_s in zip(state_tuple, cu_state_tuple): + self.assertAllClose(s, cu_s, rtol=rtol, atol=atol) + for sg, cu_sg in zip(state_grad, cu_state_grad): + self.assertAllClose(sg, cu_sg, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) + self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) - weights_v, biases_v = sess.run([weights, biases]) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) - sess.run(reset_params) - saver.restore(sess, save_path) - weights_v_restored, biases_v_restored = sess.run([weights, biases]) - - self._CompareWeights(weights_v, weights_v_restored) - self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, - direction) - - def _testSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - names = ["rnn_1", "rnn_2"] - param_vars = [ - variables.VariableV1( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) for name in names - ] - saveables = [] - for name, params 
in zip(names, param_vars): - saveables.append(_CreateParamsSavable(params, model, name, name)) - weights1, biases1 = saveables[0].format_converter._opaque_to_cu_canonical( - saveables[0]._variables) - weights2, biases2 = saveables[1].format_converter._opaque_to_cu_canonical( - saveables[1]._variables) - reset_params = [ - state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) for params in param_vars - ] - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session(use_gpu=True, - graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - weights1_v, biases1_v = sess.run([weights1, biases1]) - weights2_v, biases2_v = sess.run([weights2, biases2]) - - sess.run(reset_params) - saver.restore(sess, save_path) - weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) - weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) - - self._CompareWeights(weights1_v, weights1_v_restored) - self._CompareWeights(weights2_v, weights2_v_restored) - self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, - direction) - self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, - direction) - - def _testSaveRestoreOutput(self, rnn_mode, direction, dtype): - with ops.Graph().as_default(): - num_layers = 2 - num_units = 7 - input_size = 7 - seq_length = 10 - batch_size = 5 - dir_count = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 - model = _CreateModel( - rnn_mode, + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, 
time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, num_layers, + is_training=False) + + self.assertAllClose(outputs, cu_outputs) + # h + self.assertAllClose(state_tuple.h, cu_state_tuple.h) + # c + self.assertAllClose(state_tuple.c, cu_state_tuple.c) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, num_units, input_size, - direction=direction, - dtype=dtype) - params_size_t = model.params_size() - params = variables.VariableV1( - array_ops.ones([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - _CreateParamsSavable(params, model) - save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) - np.random.seed(1234) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - input_data = constant_op.constant( - np.random.randn(seq_length, batch_size, input_size), dtype=dtype) - 
input_h = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, num_units), - dtype=dtype) - if has_input_c: - input_c = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, num_units), - dtype=dtype) - outputs = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - outputs = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - total_sum = sum(map(math_ops.reduce_sum, outputs)) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run(total_sum) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - sess.run(reset_params) - saver.restore(sess, save_path) - total_sum_v_restored = sess.run(total_sum) - self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + # h + self.assertAllClose( + state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) + # c + self.assertAllClose( + state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSaveRestore(self): - rnn_modes = [ - cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU, - cudnn_rnn_ops.CUDNN_RNN_TANH, cudnn_rnn_ops.CUDNN_RNN_RELU - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - dtype_list = [dtypes.float32, 
dtypes.float64] - for rnn_mode, direction, dtype in itertools.product(rnn_modes, directions, - dtype_list): - self._testSaveRestoreVariable(rnn_mode, direction, dtype) - self._testSaveRestoreTwoVariables(rnn_mode, direction, dtype) - self._testSaveRestoreOutput(rnn_mode, direction, dtype) - - -class CudnnRNNTestParamsSize(TensorFlowTestCase): - - def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, - direction): - logging.info("Testing one lstm param size with config: %s", locals()) - min_params_size = _MinLSTMParamSize(num_layers, num_units, input_size, - direction) - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - num_layers, + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + # Hand-picked dropouts are used below (0. and 1.) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) + + self.assertAllClose(cu_outputs, cu_outputs2) + # h + self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) + # c + self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) + + +def RunGRU(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. 
+ # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. when test training.") + + # set graph level random seed and numpy random seed. + random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + with variable_scope.variable_scope("test", initializer=initializer): + gate_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/kernel", + shape=[input_size + num_units, num_units * 2], + dtype=dtype) + gate_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/bias", + shape=[num_units * 2], + dtype=dtype) + candidate_inp_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/kernel", + shape=[input_size, num_units], + dtype=dtype) + candidate_inp_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/bias", + shape=[num_units], + dtype=dtype) + candidate_hid_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/kernel", + shape=[num_units, num_units], + dtype=dtype) + candidate_hid_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/bias", + shape=[num_units], + dtype=dtype) + + cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) + outputs_op, h_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=initial_h_op, + dtype=dtype, + time_major=True, + scope=None) + + ws = [gate_kernel, candidate_inp_kernel, 
candidate_hid_kernel] + bs = [gate_bias, candidate_inp_bias, candidate_hid_bias] + # Convert to cudnn opaque param. + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + array_ops.zeros_like(cu_initial_h_op), # not used + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_GRU) + + if is_training: + (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, + cib_grad_op, chb_grad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op] + ws + bs) + + (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op) = cu_wgrad_op + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) = cu_bgrad_op + # cudnn gru has 2 biases for reset and update gates. When converting to tf + # canonical format, the two biases are summed into one. Thus here relevant + # bias gradient should be halved before comparing with tf gru. 
+ cu_gb_grad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, h, inp_grad, hgrad, wgrad, bgrad = sess.run([ + outputs_op, h_op, inp_grad_op, hgrad_op, + (gk_grad_op, cik_grad_op, chk_grad_op), + (gb_grad_op, cib_grad_op, chb_grad_op) + ]) + (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([ + cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) + ]) + # Remove the trivial 1st dimension + cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "hgrad: %s" % hgrad) + logging.vlog(1, "cu_hgrad: %s" % cu_hgrad) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) + else: + outputs, h = sess.run([outputs_op, h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + # Remove the trivial 1st dimension. 
+ cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + return outputs, cu_outputs, h, cu_h + + +class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): + + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, num_units, input_size, batch_size, time, num_layers) + + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + self.assertAllClose(hgrad, cu_hgrad, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + for bg, cu_bg in zip(bgrad, cu_bgrad): + self.assertAllClose(bg, cu_bg, rtol=rtol, atol=atol) + for wg, cu_wg in zip(wgrad, cu_wgrad): + self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper( num_units, input_size, - direction=direction) - params_size = model.params_size() - with self.test_session(use_gpu=True, 
graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size) - self.assertLessEqual(min_params_size, params_size_v) + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testLSTMParamsSize(self): - test_configs = [ - [4, 200, 200], - [4, 200, 300], - [4, 200, 100], - [1, 100, 200], - [2, 200, 100], - [3, 200, 400], - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - for (config, direction) in itertools.product(test_configs, directions): - num_layers, num_units, input_size = config - with ops.Graph().as_default(): - self._testOneLSTMParamsSize(num_layers, num_units, input_size, - direction) + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False) + self.assertAllClose(outputs, cu_outputs) + self.assertAllClose(h, cu_h) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testLSTMParamsSizeShape(self): - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - constant_op.constant([4]), 200, 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - _ = model.params_size() - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, constant_op.constant([200]), 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - _ = model.params_size() - with self.assertRaisesRegexp( - ValueError, "Shape 
must be rank 0 but is rank 1"): - model = _CreateModel( + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) + + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + # Hand-picked dropouts are used below (0. and 1.) + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_h2) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) 
+ + self.assertAllClose(cu_outputs, cu_outputs2) + self.assertAllClose(cu_h[0], cu_h2[0]) + + +class CudnnParamsFormatConverterTest(TensorFlowTestCase, + parameterized.TestCase): + """Class for testing various format converters.""" + + def _test_lstm_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size, direction=direction) + + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + w = constant_op.constant( + np.random.rand(input_size + num_units, 4 * num_units), + dtype=dtypes.float32) + b = constant_op.constant( + np.random.rand(4 * num_units), dtype=dtypes.float32) + ws.append(w) + bs.append(b) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( cudnn_rnn_ops.CUDNN_LSTM, - 4, 200, constant_op.constant([200]), - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - _ = model.params_size() + num_layers, + num_units, + input_size, + direction=direction) + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) -class CudnnRNNTestInference(TensorFlowTestCase): + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) - def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, - expected, tolerance): - random_seed.set_random_seed(5678) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - input_mode="auto_select", - direction=(cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION), - dropout=dropout) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - params_size_t = model.params_size() - input_data = array_ops.ones([seq_length, batch_size, input_size]) - input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - params = variables.VariableV1( - array_ops.ones([params_size_t]), validate_shape=False) - if has_input_c: - input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - output, output_h = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run([total_sum]) + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = sum(x.size for x in ws) + np.sum(x.size for x in bs) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, 
opaque_params_size_v) - self.assertAllClose( - total_sum_v[0], expected, atol=tolerance, rtol=tolerance) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_lstm(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleInference(self): - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "expected": 231833.22, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "expected": 130688, - "tolerance": 1e-2, - "shape": { - "num_layers": 2, - "num_units": 8, - "input_size": 4, - "batch_size": 4, - "seq_length": 2, - "dir_count": 1, - }, - }, - ] - # Cudnn scales result for dropout during training, therefore dropout has no - # impact for inference results. - # (lstm, gru, rnn_tanh are saturated in the test. 
rnn_relu case is most - # demonstrative of the dropout-invariant nature of CudnnRnn.) - dropouts = [0., 0.5, 1.] - for (config, dropout) in itertools.product(test_configs, dropouts): - rnn_mode = config["rnn_mode"] - expected = config["expected"] - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleInference( - rnn_mode, shape["num_layers"], shape["num_units"], - shape["input_size"], shape["batch_size"], shape["seq_length"], - shape["dir_count"], dropout, expected, tolerance) - - -class CudnnRNNTestTraining(TensorFlowTestCase): - - def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, dtype, - delta, tolerance): - # Gradient checking runs two forward ops with almost the same input. Need to - # make sure the drop patterns across the two runs are the same. - logging.info("Training test with config: %s", locals()) - old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - random_seed.set_random_seed(5678) - direction = (cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - params_size_t = model.params_size() - input_data = variables.VariableV1( - random_ops.random_uniform( - [seq_length, batch_size, input_size], dtype=dtype), - dtype=dtype) - input_h = variables.VariableV1( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) - params = variables.VariableV1( - random_ops.random_uniform([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - if has_input_c: - input_c = variables.VariableV1( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, 
num_units], dtype=dtype), - dtype=dtype) - - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params) - else: - output, output_h = model( - input_data=input_data, input_h=input_h, params=params) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size_t) - inputs_and_shapes = [ - (input_data, [seq_length, batch_size, input_size]), - (input_h, [num_layers * dir_count, batch_size, num_units]), - (params, [params_size_v]), - ] - if has_input_c: - inputs_and_shapes.append( - (input_c, [num_layers * dir_count, batch_size, num_units]),) - sess.run(variables.global_variables_initializer()) - all_inputs = [entry[0] for entry in inputs_and_shapes] - all_shapes = [entry[1] for entry in inputs_and_shapes] - - err = gradient_checker.compute_gradient_error( - all_inputs, all_shapes, total_sum, [1], delta=delta) - - self.assertLess(err, tolerance) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + def test_lstm_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + def _test_gru_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size, direction=direction) + + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + gate_kernel = constant_op.constant( + np.random.rand(input_size + 
num_units, num_units * 2), + dtype=dtypes.float32) + gate_bias = constant_op.constant( + np.random.rand(num_units * 2), dtype=dtypes.float32) + candidate_inp_kernel = constant_op.constant( + np.random.rand(input_size, num_units), dtype=dtypes.float32) + candidate_inp_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + candidate_hid_kernel = constant_op.constant( + np.random.rand(num_units, num_units), dtype=dtypes.float32) + candidate_hid_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + ws.extend([gate_kernel, candidate_inp_kernel, candidate_hid_kernel]) + bs.extend([gate_bias, candidate_inp_bias, candidate_hid_bias]) + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + cudnn_rnn_ops.CUDNN_GRU, + num_layers, + num_units, + input_size, + direction=direction) + + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) + + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = sum(x.size for x in ws) + sum(x.size for x in bs) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_gru(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def DISABLED_testSimpleTraining(self): - # TODO(jamesqin): fix b/117989214 - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - 
"seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float32, - "tolerance": 1.5e-2, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float32, - "tolerance": 4e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float32, - "tolerance": 5e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float32, - "tolerance": 5e-1, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - ] - dropouts = [0., 0.5, 1.] 
- dir_counts = [1] - for config, dropout, dir_count in itertools.product(test_configs, dropouts, - dir_counts): - rnn_mode = config["rnn_mode"] - dtype = config.get("dtype", dtypes.float32) - delta = config.get("delta", 1e-3) - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleTraining(rnn_mode, shape["num_layers"], - shape["num_units"], shape["input_size"], - shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, tolerance) + def test_gru_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + +class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): + """Class for testing various Cudnn Rnn SaveableObjects.""" + + def _create_opaque_param(self, + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name=None): + param_size_t = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + rnn_mode, num_layers, num_units, input_size, direction=direction) + init_val = random_ops.random_uniform([param_size_t]) + return variable_scope.get_variable( + name or "opaque_param", initializer=init_val, validate_shape=False) + + def _create_saveable(self, opaque_param, rnn_mode, num_units, input_size, + num_layers, direction): + if rnn_mode == CUDNN_LSTM: + fn = cudnn_rnn_ops.CudnnLSTMSaveable + elif rnn_mode == CUDNN_GRU: + fn = cudnn_rnn_ops.CudnnGRUSaveable + elif rnn_mode == CUDNN_RNN_TANH: + fn = cudnn_rnn_ops.CudnnRNNTanhSaveable + elif rnn_mode == CUDNN_RNN_RELU: + fn = cudnn_rnn_ops.CudnnRNNReluSaveable + saveable = fn( + opaque_param, num_layers, num_units, input_size, direction=direction) + return saveable + + def _compare_weights(self, lhs, rhs): + self.assertLen(rhs, len(lhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _compare_biases(self, lhs, rhs): + self.assertLen(rhs, 
len(lhs)) + for lf, rt in zip(lhs, rhs): + self.assertAllEqual(lf, rt) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_variable(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size, + num_layers, direction) + saveable = self._create_saveable(opaque_param, rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + weights_op, biases_op = saveable.format_converter.opaque_to_tf_canonical( + saveable._variables) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_op = state_ops.assign(opaque_param, + array_ops.zeros_like(opaque_param)) + sess.run(init_op) + self.assertEqual(save_path, saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + weights, biases = sess.run([weights_op, biases_op]) + + # Reset the opaque param value + sess.run(reset_op) + # Assert reset happened. + weights_z, biases_z = sess.run([weights_op, biases_op]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. 
+ saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_op, biases_op]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_params = [] + saveables = [] + num_opaque_params = 2 + for i in range(num_opaque_params): + opaque_params.append( + self._create_opaque_param( + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name="opaque_param_%d" % i)) + saveable = self._create_saveable(opaque_params[i], rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saveables.append(saveable) + + weights_ops, biases_ops = [], [] + for i in range(num_opaque_params): + weights_op, biases_op = ( + saveables[i].format_converter.opaque_to_tf_canonical( + saveables[i]._variables)) + weights_ops.append(weights_op) + biases_ops.append(biases_op) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_ops = [] + for i in range(num_opaque_params): + reset_ops.append( + state_ops.assign(opaque_params[i], + array_ops.zeros_like(opaque_params[i]))) + sess.run(init_op) + self.assertEqual(save_path, 
saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + for i in range(num_opaque_params): + weights, biases = sess.run([weights_ops[i], biases_ops[i]]) + + # Reset the opaque param value + sess.run(reset_ops[i]) + + # Assert reset happened. + weights_z, biases_z = sess.run([weights_ops[i], biases_ops[i]]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. + saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_ops[i], biases_ops[i]]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) if __name__ == "__main__": diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 1954f6717bbebd803b0ec45992b43cf68f5d72a0..7e1b4062ce435f3ab4216e90b4f5fcbab984c1dc 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -536,7 +536,9 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): save_path = os.path.join(self.get_temp_dir(), "save-restore-variable-test") saver = saver_lib.Saver() - weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + weights, biases = ( + model.rnn.saveable.format_converter._opaque_to_cu_canonical( + model.rnn.saveable._variables)) opaque_params = rnn.trainable_variables[0] # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save # Cudnn vars in canonical format. 
@@ -583,8 +585,12 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): dtype=dtype) opaque_params = (model1.rnn.trainable_variables[0], model2.rnn.trainable_variables[0]) - weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() - weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + saveable1 = model1.rnn.saveable + weights1, biases1 = saveable1.format_converter._opaque_to_cu_canonical( + saveable1._variables) + saveable2 = model2.rnn.saveable + weights2, biases2 = saveable2.format_converter._opaque_to_cu_canonical( + saveable2._variables) reset_params = [ state_ops.assign(params, array_ops.zeros_like(params, dtype=dtype)) @@ -1039,8 +1045,8 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase): # Min param size estimate = sum(weights.size) + sum(biases.size) min_params_size = ( - np.sum(list(map(np.prod, rnn.canonical_weight_shapes))) + - np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + sum(map(np.prod, rnn.canonical_weight_shapes)) + + sum(sp[0] for sp in rnn.canonical_bias_shapes)) opaque_params = rnn.trainable_variables[0] with self.test_session(use_gpu=True, graph=ops.get_default_graph()): diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 8bbcc7cd0397a5339a69e4e44528f0e56584043a..8e25637ed91a1559b321ea96efbfaa2910f67158 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -21,6 +21,7 @@ from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -322,7 +323,7 @@ class _CudnnRNN(base_layer.Layer): raise
ValueError("The last dimension of the inputs to `CudnnRNN` " "should be defined. Found `None`.") self._input_size = input_shape[-1].value - self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + self.input_spec = input_spec.InputSpec(ndim=3, axes={-1: self._input_size}) self._set_scope(None) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index d06d0c6bdaa113089c4d4239a6d4ed216ddd01a8..1ce29b42d52ff67477161278ed11016c2e73041d 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -738,7 +738,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): self._variables, opaque_params, validate_shape=False) def _checkpointable_save(self, save_buffer): - weights, biases = self.format_converter.opaque_params_to_tf_canonical( + weights, biases = self.format_converter.opaque_to_tf_canonical( self._variables) for name, tensor in zip(self._param_names, weights + biases): save_buffer[name] = array_ops.identity(tensor) diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 0456463a1928cf226010670b90a5d574579e0411..6c5f8c6b00975b3fba041271309a93cecd9f5057 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -46,7 +46,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -88,7 +88,7 @@ class 
AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -115,9 +115,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -142,7 +141,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): tensor_shape.TensorShape((3, 4))) self.assertEqual(actual_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -184,7 +183,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -211,9 +210,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + 
dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index d2a72272db159755ac2d741bcdbce9ec646d928e..b9840b1ff1a3df5a05db0e64f436637220f49f80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -23,6 +23,7 @@ import shutil from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -48,7 +49,7 @@ class LMDBDatasetTest(test_base.DatasetTestBase): num_repeats = 2 dataset = readers.LMDBDataset(filenames).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index c5a786232252432481566e3cde23e9310df172cc..2527706709fae8e459aca3489324d4db3c784be6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -63,13 +63,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> # _SlideDataset(window_size, window_shift, window_stride). 
- iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -127,13 +127,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, stride=stride_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -173,12 +173,12 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer with self.cached_session() as sess: @@ -204,9 +204,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + 
iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -233,9 +233,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): values=array_ops.fill([math_ops.to_int32(i)], i), dense_shape=[i]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -265,11 +265,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(_sparse).apply( sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) init_op = iterator.initializer get_next = iterator.get_next() @@ -305,11 +304,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) next_element = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/ops/BUILD 
b/tensorflow/contrib/data/python/ops/BUILD index 34dc2379d0cb38f8f6962fa42efe21b793bc8d65..0fb406f1167053a128646c5c692986b0ce016f1e 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -188,8 +188,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:function", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4601376dff47e161962e92678883039c4b88bab7..c0152156a1ba70297adb7054622b15ca04f859cd 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -21,10 +21,9 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util import deprecation @@ -355,7 +354,7 @@ def read_batch_features(file_pattern, shuffle=randomize_input, num_epochs=num_epochs, shuffle_buffer_size=capacity) - iterator = dataset.make_one_shot_iterator() + iterator = dataset_ops.make_one_shot_iterator(dataset) outputs = iterator.get_next() return outputs @@ -379,15 +378,13 @@ class LMDBDataset(dataset_ops.DatasetSource): (key value) pairs sequentially. 
For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: filenames: A `tf.string` tensor containing one or more filenames. @@ -398,18 +395,10 @@ class LMDBDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_experimental_dataset_ops.experimental_lmdb_dataset( - self._filenames, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + self._filenames, **dataset_ops.flat_structure(self)) @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index bcc383587c54bd89502313f9328bc06c49046a87..5c6ee6bfdc7167d14b292f8f763adafca4e3a72c 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -18,11 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops 
import gen_experimental_dataset_ops as ged_ops from tensorflow.python.util import deprecation @@ -40,8 +39,13 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") + input_structure = structure.convert_legacy_structure( + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + self._structure = input_structure._batch(None) # pylint: disable=protected-access + def _as_variant_tensor(self): - return gen_dataset_ops.slide_dataset( + return ged_ops.experimental_sliding_window_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, @@ -49,20 +53,8 @@ class _SlideDataset(dataset_ops.UnaryDataset): **dataset_ops.flat_structure(self)) @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - input_shapes = self._input_dataset.output_shapes - return nest.pack_sequence_as(input_shapes, [ - tensor_shape.vector(None).concatenate(s) - for s in nest.flatten(self._input_dataset.output_shapes) - ]) - - @property - def output_types(self): - return self._input_dataset.output_types + def _element_structure(self): + return self._structure @deprecation.deprecated_args( diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index a87a5624c88d1d0af10055261dad55937ed6aeb0..3ecd755d86f6be47910aebbdb46d335d165427d8 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -26,7 +26,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", - "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", 
"//tensorflow/contrib/distribute/python:one_device_strategy", @@ -35,6 +34,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:distribute_config", "//tensorflow/python/distribute:distribute_coordinator", ], diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index a938f8629d8210b4b512338a040340f21d3ef594..8a8dc159ade6f2a4a9b5ec29055ea4848492b29f 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -134,7 +134,7 @@ def model_fn(features, labels, mode): return tf.estimator.EstimatorSpec(mode, loss=loss) if mode == tf.estimator.ModeKeys.TRAIN: - train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) ``` @@ -248,19 +248,17 @@ Let's use the same example for multi-worker. We'll start a cluster with 3 workers doing synchronous all-reduce training. In the following code snippet, we start multi-worker training using `tf.estimator.train_and_evaluate`: - ```python def model_main(): - estimator = ... distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig(train_distribute=distribution) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) ``` - **Note**: You don't have to set "TF\_CONFIG" manually if you use our provided Kubernetes template. @@ -327,13 +325,13 @@ start training. On your laptop, you can run ```python -estimator = ... 
distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig( experimental_distribute=tf.contrib.distribute.DistributeConfig( train_distribute=distribution, remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]})) +estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index ab2f221dc6486666e914deb19dd56c7687606e2f..8ec73654e30e4967f318c558ba94301e84a206e4 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -25,13 +25,13 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy -from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server from tensorflow.python.training.distribute import * @@ -46,6 +46,7 @@ _allowed_symbols = [ 'CrossDeviceOps', 'DistributeConfig', 'DistributionStrategy', + 'DistributionStrategyExtended', 
'MirroredStrategy', 'Monitor', 'MultiWorkerAllReduce', @@ -62,6 +63,7 @@ _allowed_symbols = [ 'get_loss_reduction', 'get_replica_context', 'has_distribution_strategy', + 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', 'UpdateContext', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 4094e52169aab0b46da4f62087ddac4f750039a4..4c9c35da5a36aa8149d15c8d1c25e4dfaa6a07c1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -16,45 +16,26 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") # TODO(priyag): Figure out testonly issues that are preventing us from # including our tests in pip for now. -py_library( - name = "values", - srcs = ["values.py"], - visibility = ["//tensorflow:internal"], - deps = [ - ":input_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:multi_device_iterator_ops", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/checkpointable:base", - "@six_archive//:six", - ], -) - cuda_py_test( name = "values_test", srcs = ["values_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:errors", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + 
"//tensorflow/python/distribute:device_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", - "//tensorflow/python:device_util", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], @@ -68,25 +49,9 @@ py_library( srcs = ["mirrored_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":shared_variable_creator", - ":values", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:values", ], ) @@ -95,16 +60,17 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -116,7 +82,7 @@ cuda_py_test( ":combinations", ":multi_worker_test_base", ":parameter_server_strategy", - ":values", + ":strategy_test_lib", "@absl_py//absl/testing:parameterized", 
"//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -127,10 +93,12 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:layers", "//tensorflow/python:session", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -145,12 +113,13 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":values", - "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:distribute", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -161,16 +130,16 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":cross_tower_utils", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -187,11 +156,11 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:layers", "//tensorflow/python:training", 
"//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -212,10 +181,10 @@ py_library( ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", "@absl_py//absl/testing:parameterized", ], @@ -233,28 +202,6 @@ py_test( ], ) -py_test( - name = "mirrored_strategy_test", - srcs = ["mirrored_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":mirrored_strategy", - ":multi_worker_test_base", - ":strategy_test_lib", - "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_test( name = "one_device_strategy_test", srcs = ["one_device_strategy_test.py"], @@ -270,35 +217,32 @@ py_test( ], ) +# TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", ":strategy_test_lib", - "//tensorflow/python:distribute", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:state_ops", "//tensorflow/python:variable_scope", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python/distribute:distribute_lib", + 
"//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 5, tags = [ "guitar", - "no_pip", "multi_and_single_gpu", - # Do not perform the extra analysis on this test, because it is already - # performed for the `:mirrored_strategy_test` target. - "no_oss", - "noasan", - "notap", - "notsan", + "no_pip", ], ) @@ -337,12 +281,15 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":one_device_strategy", - ":values", "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", ], ) @@ -352,7 +299,6 @@ cuda_py_test( additional_deps = [ ":collective_all_reduce_strategy", ":combinations", - ":cross_tower_utils", ":multi_worker_test_base", ":strategy_test_lib", "@absl_py//absl/testing:parameterized", @@ -368,6 +314,7 @@ cuda_py_test( "//tensorflow/python:layers", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -469,6 +416,7 @@ cuda_py_test( "multi_and_single_gpu", "no_oss", # http://b/119349471 "no_pip", + "tf_integration_test", ], ) @@ -476,28 +424,18 @@ cuda_py_test( name = "keras_optimizer_v2_test", srcs = ["keras_optimizer_v2_test.py"], additional_deps = [ - ":combinations", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/feature_column", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - 
"//tensorflow/python:summary", + ":keras_test_lib", ], tags = [ "multi_and_single_gpu", "no_oss", # http://b/119349471 "no_pip", + "tf_integration_test", ], ) cuda_py_test( name = "estimator_training_test", - size = "large", srcs = ["estimator_training_test.py"], additional_deps = [ ":collective_all_reduce_strategy", @@ -508,7 +446,9 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute", + "//tensorflow/python/distribute:distribute_config", + "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column", @@ -516,7 +456,7 @@ cuda_py_test( "//tensorflow/python:platform", "//tensorflow/python:summary", ], - shard_count = 5, + shard_count = 48, tags = [ "multi_and_single_gpu", "no_pip", @@ -524,6 +464,7 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "no_oss", # http://b/119349471 ], ) @@ -599,52 +540,16 @@ cuda_py_test( ], ) -py_library( - name = "shared_variable_creator", - srcs = ["shared_variable_creator.py"], - visibility = ["//tensorflow:internal"], -) - -py_test( - name = "shared_variable_creator_test", - srcs = ["shared_variable_creator_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":shared_variable_creator", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:test", - ], -) - -py_library( - name = "cross_tower_utils", - srcs = ["cross_tower_utils.py"], - srcs_version = "PY2AND3", - deps = [ - ":values", - "//tensorflow/contrib/all_reduce:all_reduce_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:device", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - 
"//tensorflow/python:nccl_ops", - ], -) - cuda_py_test( - name = "cross_tower_utils_test", - srcs = ["cross_tower_utils_test.py"], + name = "cross_device_utils_test", + srcs = ["cross_device_utils_test.py"], additional_deps = [ ":combinations", - ":cross_tower_utils", "@absl_py//absl/testing:parameterized", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -653,40 +558,20 @@ cuda_py_test( ], ) -py_library( - name = "cross_tower_ops", - srcs = ["cross_tower_ops.py"], - srcs_version = "PY2AND3", - deps = [ - ":cross_tower_utils", - ":values", - "//tensorflow/python:array_ops", - "//tensorflow/python:device_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "@six_archive//:six", - ], -) - cuda_py_test( - name = "cross_tower_ops_test", - srcs = ["cross_tower_ops_test.py"], + name = "cross_device_ops_test", + srcs = ["cross_device_ops_test.py"], additional_deps = [ ":combinations", - ":cross_tower_ops", ":multi_worker_test_base", ":mirrored_strategy", - ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -696,37 +581,6 @@ cuda_py_test( ], ) -py_library( - name = "input_ops", - srcs = ["input_ops.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/util:nest", - ], -) - -cuda_py_test( - name = 
"input_ops_test", - srcs = ["input_ops_test.py"], - additional_deps = [ - ":input_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:errors", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:io_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python:util", - ], - tags = [ - "no_pip", - ], -) - py_library( name = "keras_test_lib", testonly = 1, @@ -737,6 +591,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", + "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", "//third_party/py/numpy", @@ -766,7 +621,6 @@ py_library( srcs = ["metrics_v1_test.py"], deps = [ ":combinations", - "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index d38bdb592a303d23871b48d80868917efc01dcd1..31bd0e996a247a2fc01405fb3b8172a40853d698 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -43,7 +43,9 @@ class CheckpointUtilsWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], in_replica_mode=[True, False], mode=["graph"])) def testInitFromCheckpoint(self, distribution, 
in_replica_mode): diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index efa99d1fc52e8facfaeb61f98b5e649a18f6a3cf..5c50a20490482856becedf7b1379d2a0583d9a11 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,12 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils +import copy + from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -32,7 +36,7 @@ from tensorflow.python.platform import tf_logging as logging # TODO(yuefengz): support in-graph replication. -class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): +class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. It is similar to the MirroredStrategy but it uses collective ops for @@ -53,6 +57,17 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): num_gpus_per_worker: number of local GPUs or GPUs per worker, the default is 0 meaning CPU only. 
""" + super(CollectiveAllReduceStrategy, self).__init__( + CollectiveAllReduceExtended(self, num_gpus_per_worker)) + + +class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): + """Implementation of CollectiveAllReduceStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + distribute_lib.DistributionStrategyExtended.__init__( + self, container_strategy) + self._cross_device_ops = None self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local_worker(num_gpus_per_worker) @@ -62,19 +77,19 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._num_workers = 1 if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = ["/device:CPU:0"] + local_devices = ("/device:CPU:0",) + self._worker_device = device_util.canonicalize("/device:CPU:0") - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( - devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( - num_workers=1, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._collective_keys = cross_device_utils.CollectiveKeys() + self._initialize_local(local_devices) + self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) self._cluster_spec = None self._task_type = None @@ -89,13 +104,12 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " "`task_type` and `task_id`") - if task_type not in ["chief", "worker"]: + if task_type not in ("chief", "worker"): raise ValueError( "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." 
% task_type) cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len( - cluster_spec.as_dict().get("chief", [])) + self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) if not self._num_workers: raise ValueError("No `worker` or `chief` tasks can be found in " "`cluster_spec`.") @@ -103,22 +117,21 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, task_id) - worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._worker_device = "/job:%s/task:%d" % (task_type, task_id) if num_gpus_per_worker: - local_devices = [ - "%s/device:GPU:%d" % (worker_device, i) + local_devices = tuple( + "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = [worker_device] + local_devices = (self._worker_device,) - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( - devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._collective_keys = cross_device_utils.CollectiveKeys() + self._initialize_local(local_devices) + self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) # Add a default device so that ops without specified devices will not end up # on other workers. 
@@ -202,17 +215,40 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return mirrored_strategy._create_mirrored_variable( devices, _real_mirrored_creator, *args, **kwargs) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices, True) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _make_dataset_iterator(self, dataset): + worker_device_pairs = [(self._worker_device, self._devices)] + return values.DatasetIterator(dataset, worker_device_pairs, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec is None: + input_pipeline_id = 0 + else: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + input_context = distribute_lib.InputContext( + num_input_pipelines=self._num_workers, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + + return values.InputFunctionIterator( + input_fn, [(self._worker_device, self._devices)], [input_context]) + + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the object. Args: @@ -232,13 +268,15 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, task_type, task_id) - if not session_config: - return + if session_config: + session_config.CopyFrom(self._update_config_proto(session_config)) + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) # Enable the scoped allocator optimization for CollectiveOps. 
This # optimization converts many small all-reduces into fewer larger # all-reduces. - rewrite_options = session_config.graph_options.rewrite_options + rewrite_options = updated_config.graph_options.rewrite_options rewrite_options.scoped_allocator_optimization = ( rewriter_config_pb2.RewriterConfig.ON) # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = @@ -248,7 +286,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") if not self._cluster_spec: - return + return updated_config assert self._task_type assert self._task_id is not None @@ -256,26 +294,28 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): # Collective group leader is needed for collective ops to coordinate # workers. if "chief" in self._cluster_spec.jobs: - session_config.experimental.collective_group_leader = ( + updated_config.experimental.collective_group_leader = ( "/job:chief/replica:0/task:0") else: if "worker" not in self._cluster_spec.jobs: raise ValueError( "You must have `chief` or `worker` jobs in the `cluster_spec`.") - session_config.experimental.collective_group_leader = ( + updated_config.experimental.collective_group_leader = ( "/job:worker/replica:0/task:0") # The device filters prevent communication between workers. 
- del session_config.device_filters[:] - session_config.device_filters.append( + del updated_config.device_filters[:] + updated_config.device_filters.append( "/job:%s/task:%d" % (self._task_type, self._task_id)) + return updated_config + @property - def between_graph(self): + def experimental_between_graph(self): return True @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -287,6 +327,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return self._is_chief @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return len(self._devices) * self._num_workers + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index e3d919dd0d482f49d9a934c879e9adad25c03f86..8a9e583f0afaac37a2057bae9b1ed79de43d68bc 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -23,13 +23,19 @@ import numpy as np from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager 
import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops @@ -51,9 +57,6 @@ class CollectiveAllReduceStrategyTestBase( collective_key_base = 0 def setUp(self): - self._run_options = config_pb2.RunOptions() - self._run_options.experimental.collective_graph_key = 6 - # We use a different key_base for each test so that collective keys won't be # reused. # TODO(yuefengz, tucker): enable it to reuse collective keys in different @@ -71,15 +74,16 @@ class CollectiveAllReduceStrategyTestBase( cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_start=num_gpus * 100 + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution._collective_keys = collective_keys - distribution._cross_tower_ops._collective_keys = collective_keys + distribution.extended._collective_keys = collective_keys + distribution.extended._inferred_cross_device_ops._collective_keys = ( + collective_keys) if task_type and task_id is not None: return distribution, 'grpc://' + self._cluster_spec[task_type][ task_id], session_config @@ -93,7 +97,8 @@ class CollectiveAllReduceStrategyTestBase( self.cached_session(config=config, target=master_target) as sess, \ d.scope(): - l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + l = core.Dense(1, use_bias=False, + name='gpu_%d' % d.extended._num_gpus_per_worker) def loss_fn(x): y = array_ops.reshape(l(x), []) - constant_op.constant(1.) 
@@ -127,8 +132,8 @@ class CollectiveAllReduceStrategyTestBase( before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -136,14 +141,13 @@ class CollectiveAllReduceStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True - sess.run( - variables.global_variables_initializer(), options=self._run_options) + sess.run(variables.global_variables_initializer()) for i in range(10): - b, a = sess.run((before_out, after_out), options=self._run_options) + b, a = sess.run((before_out, after_out)) if i == 0: before, = b after, = a @@ -222,26 +226,54 @@ class CollectiveAllReduceStrategyTestBase( return array_ops.identity(x) x = distribution.call_for_each_replica(model_fn) - reduced_x = distribution.unwrap( - distribution.reduce( - variable_scope.VariableAggregation.MEAN, x, - destinations='/cpu:0'))[0] + reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) x = distribution.unwrap(x)[0] - sess.run( - variables.global_variables_initializer(), options=self._run_options) + sess.run(variables.global_variables_initializer()) - x_value, reduced_x_value = sess.run([x, reduced_x], - options=self._run_options) + x_value, reduced_x_value = sess.run([x, reduced_x]) self.assertTrue( np.allclose(x_value, reduced_x_value, atol=1e-5), msg=('x_value = %r, reduced_x_value = %r' % (x_value, reduced_x_value))) return np.allclose(x_value, reduced_x_value, atol=1e-5) + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_object( + task_type, task_id, 
num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class DistributedCollectiveAllReduceStrategyTest( - CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -269,7 +301,7 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, @@ -279,10 +311,56 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + # TODO(yuefengz): Update how we use num_gpus and 
required_gpus + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMakeInputFnIterator(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + # We use CPU as the device when num_gpus = 0 + devices_per_worker = max(1, num_gpus) + expected_values = [[i+j for j in range(devices_per_worker)] + for i in range(0, 100, devices_per_worker)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=3*devices_per_worker, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + def testUpdateConfigProto(self): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=2) + distribution.configure( + cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + + config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) + rewrite_options = config_proto.graph_options.rewrite_options + rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') + + new_config = distribution.update_config_proto(config_proto) + + # Verify group leader + self.assertEqual('/job:worker/replica:0/task:0', + new_config.experimental.collective_group_leader) + + # Verify device filters. + self.assertEqual(['/job:worker/task:1'], new_config.device_filters) + + # Verify rewrite options. 
+ new_rewrite_options = new_config.graph_options.rewrite_options + self.assertEqual(rewriter_config_pb2.RewriterConfig.ON, + new_rewrite_options.scoped_allocator_optimization) + self.assertEqual(['CollectiveReduce'], + new_rewrite_options.scoped_allocator_opts.enable_op) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -293,10 +371,6 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0, has_chief=True) - def setUp(self): - super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp() - self._run_options.experimental.collective_graph_key = 7 - @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testMinimizeLossGraph(self, num_gpus): @@ -323,20 +397,36 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_minimize_loss_graph(None, None, num_gpus) def testComplexModel(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_complex_model(None, None, num_gpus) + def testMakeInputFnIterator(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. 
+ if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index a51371654031e32d084e2b0e8ae345bb2c166ae8..365ce5cdec79f1914f0c9ccdf59a7dc59e6f819e 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -53,11 +53,11 @@ from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop from tensorflow.python.util import tf_inspect @@ -168,6 +168,8 @@ def _augment_with_special_arguments(test_method): if GPU_TEST: self.skipTest("Test that doesn't require GPUs.") elif context.num_gpus() < required_gpus: + # TODO(priyag): Consider allowing tests in graph mode using soft + # placement. self.skipTest( "{} GPUs are not available for this test. {} GPUs are available". 
format(required_gpus, context.num_gpus())) @@ -190,7 +192,7 @@ def _augment_with_special_arguments(test_method): kwargs_to_pass[arg] = kwargs[arg] if mode == "eager": - with ops.Graph().as_default(), context.eager_mode(): + with context.eager_mode(): if distribution: kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) @@ -335,6 +337,13 @@ tpu_strategy_one_step = NamedDistribution( "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) +mirrored_strategy_with_one_cpu = NamedDistribution( + "Mirrored1CPU", + lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) +mirrored_strategy_with_one_gpu = NamedDistribution( + "Mirrored1GPU", + lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]), + required_gpus=1) mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), @@ -343,6 +352,21 @@ mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) +core_mirrored_strategy_with_one_cpu = NamedDistribution( + "CoreMirrored1CPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/cpu:0"])) +core_mirrored_strategy_with_one_gpu = NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "CoreMirroredCPUAndGPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/cpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_two_gpus = NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/gpu:1"]), + required_gpus=2) gradient_descent_optimizer_v1_fn = NamedObject( @@ -373,8 +397,11 @@ def distributions_and_v1_optimizers(): """A common set of combination with DistributionStrategies and Optimizers.""" return combine( distribution=[ - one_device_strategy, 
mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v1) @@ -383,7 +410,10 @@ def distributions_and_v2_optimizers(): """DistributionStrategies and V2 Optimizers.""" return combine( distribution=[ - one_device_strategy, mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py similarity index 79% rename from tensorflow/contrib/distribute/python/cross_tower_ops_test.py rename to tensorflow/contrib/distribute/python/cross_device_ops_test.py index 3e274ba67ca6709a14f5391968f28b721e46b8a6..d6e9521c1c1115ffdbdcf375ad4017bacb962832 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -24,24 +24,24 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import 
device_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import device_util def _make_per_replica(values, devices, regroup=False): - devices = cross_tower_ops_lib.get_devices_from(devices) + devices = cross_device_ops_lib.get_devices_from(devices) assert len(values) == len(devices) # We simulate the result of regroup called on PerReplica which strips the @@ -66,7 +66,7 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib.get_devices_from(devices) + devices = cross_device_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) @@ -118,8 +118,8 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) - def _testReductionAndBroadcast(self, cross_tower_ops, distribution): - devices = distribution.worker_devices + def _testReductionAndBroadcast(self, cross_device_ops, distribution): + devices = distribution.extended.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] per_replica = _make_per_replica(values, devices) @@ -132,35 +132,33 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): destination_mirrored = _fake_mirrored(1., devices) destination_different = _fake_mirrored(1., _cpu_device) destination_str = _cpu_device - destination_list = devices all_destinations = [ destination_mirrored, destination_different, 
destination_str, - destination_list ] # test reduce() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, per_replica, destinations=destinations), _fake_mirrored(mean, destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_replica, + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices), destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations)) @@ -168,16 +166,16 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - cross_tower_ops.batch_reduce( - vs.VariableAggregation.MEAN, + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.MEAN, [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ]) self._assert_values_equal( - cross_tower_ops.batch_reduce( - vs.VariableAggregation.SUM, + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.SUM, [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices), d1), @@ -187,7 +185,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # test broadcast() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + cross_device_ops.broadcast(constant_op.constant(1.), destinations), 
_fake_mirrored(1., destinations)) @@ -196,62 +194,65 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): # combinations module so that we can pass in devices instead of a distribution # strategy. reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "DefaultReductionToOneDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), combinations.NamedObject( "ReductionToCPUDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_cpu_device)), combinations.NamedObject( "AccumulateNCrossDeviceOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( accumulation_fn=math_ops.accumulate_n)), ], distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], mode=["graph", "eager"]) allreduce_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "AllReduce", - cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), combinations.NamedObject( "HierarchicalCopy", - cross_tower_ops_lib.AllReduceCrossDeviceOps( + cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 8, 0, 0)), combinations.NamedObject( "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), combinations.NamedObject( "HierarchicalCopyAggregateSmallTensors", - cross_tower_ops_lib.AllReduceCrossDeviceOps( + 
cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 0, 100, 10)) ], - distribution=[combinations.mirrored_strategy_with_two_gpus], + distribution=[combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"]) @combinations.generate(reduction_to_one_combinations + allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def testReductionAndBroadcast(self, cross_device_ops, distribution): with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) + self._testReductionAndBroadcast(cross_device_ops, distribution) def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "hierarchical_copy") self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "nccl") self.assertEqual(result._num_packs, 1) @@ -259,16 +260,16 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] - result = 
cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "hierarchical_copy") self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "nccl") self.assertEqual(result._num_packs, 1) @@ -280,8 +281,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) - result = cross_tower_ops_lib._simple_reduce( - per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + result = cross_device_ops_lib._simple_reduce( + per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. 
@@ -294,19 +295,19 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): @combinations.generate( combinations.combine( - cross_tower_ops_instance=[ + cross_device_ops_instance=[ combinations.NamedObject( "ReductionToOneDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), combinations.NamedObject( "AllReduceCrossDeviceOps", - cross_tower_ops_lib.AllReduceCrossDeviceOps()) + cross_device_ops_lib.AllReduceCrossDeviceOps()) ], - aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN], batch_reduce=[True, False], mode=["graph", "eager"], required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op, batch_reduce): devices = ["/cpu:0", "/gpu:0"] dense_shape = [5, 2] @@ -316,20 +317,20 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce( - aggregation, [(per_replica, devices)]) + result = cross_device_ops_instance.batch_reduce( + reduce_op, [(per_replica, per_replica)]) else: - result = cross_tower_ops_instance.reduce( - aggregation, per_replica, devices) + result = cross_device_ops_instance.reduce( + reduce_op, per_replica, per_replica) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] - if aggregation == vs.VariableAggregation.SUM: + if reduce_op == reduce_util.ReduceOp.SUM: total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] total_values_without_dups = [[4., 6.], [5., 6.]] else: - assert aggregation == vs.VariableAggregation.MEAN + assert reduce_op == reduce_util.ReduceOp.MEAN total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] total_values_without_dups = [[2., 3.], [2.5, 3.]] @@ -356,49 +357,63 @@ 
class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] multi_worker_allreduce_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "MultiWorkerAllReduce", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), combinations.NamedObject( "MultiWorkerAllReducePack", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), combinations.NamedObject( "MultiWorkerAllReduceAggregation", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), combinations.NamedObject( "MultiWorkerAllReduceMultipleSpecs", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, [("pscpu/pscpu", 2, 100), ("xring", 2, -1)], 0, 0, 0)), ], distribution=[ combinations.NamedDistribution( "MirroredCPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=0), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=0), required_gpus=0), combinations.NamedDistribution( "Mirrored1GPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=1), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=1), required_gpus=1), combinations.NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=2), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2), + required_gpus=2), + # pylint: disable=g-long-lambda + combinations.NamedDistribution( + "CoreMirroredCPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]), + required_gpus=0), + combinations.NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:GPU:0"]), + required_gpus=1), + 
combinations.NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]), required_gpus=2), ], mode=["graph"]) @combinations.generate(multi_worker_allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def testReductionAndBroadcast(self, cross_device_ops, distribution): distribution.configure(cluster_spec={ "worker": ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"] }) with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) + self._testReductionAndBroadcast(cross_device_ops, distribution) class MultiWorkerCollectiveAllReduceTest( @@ -419,7 +434,7 @@ class MultiWorkerCollectiveAllReduceTest( MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + MultiWorkerCollectiveAllReduceTest.collective_key_base, instance_key_start=num_gpus * 100 + @@ -427,7 +442,7 @@ class MultiWorkerCollectiveAllReduceTest( instance_key_with_id_start=num_gpus * 10000 + MultiWorkerCollectiveAllReduceTest.collective_key_base) if local_mode: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( 1, num_gpus, collective_keys=collective_keys) if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] @@ -435,7 +450,7 @@ class MultiWorkerCollectiveAllReduceTest( devices = ["/device:CPU:0"] return collective_all_reduce_ops, devices, "" else: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( 3, num_gpus, collective_keys=collective_keys) if num_gpus: devices = [ @@ -491,37 +506,35 @@ class MultiWorkerCollectiveAllReduceTest( 
destination_mirrored = _fake_mirrored(1., devices) destination_different = _fake_mirrored(1., _cpu_device) destination_str = _cpu_device - destination_list = devices all_destinations = [ - destination_different, destination_mirrored, destination_str, - destination_list + destination_different, destination_mirrored, destination_str ] # test reduce() for destinations in all_destinations: self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica, destinations=destinations), _fake_mirrored(mean, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, + reduce_util.ReduceOp.MEAN, per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, + reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), @@ -530,7 +543,7 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, [(per_replica, d1), (per_replica_2, d2)]), [ @@ -538,7 +551,7 @@ class MultiWorkerCollectiveAllReduceTest( _fake_mirrored(mean_2, d2) ], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, [(per_replica, d1), (per_replica_2, d2)]), [ diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py 
b/tensorflow/contrib/distribute/python/cross_device_utils_test.py similarity index 76% rename from tensorflow/contrib/distribute/python/cross_tower_utils_test.py rename to tensorflow/contrib/distribute/python/cross_device_utils_test.py index a991156ca87fb666f9e47462ccf2bbbe305fe925..2303a31677afbd12a0b8e7eea3ecf7c7736c46ad 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for cross_tower_utils.""" +"""Tests for cross_device_utils.""" from __future__ import absolute_import from __future__ import division @@ -21,14 +21,14 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils -from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops -from tensorflow.python.training import device_util class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): @@ -43,7 +43,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + result = 
cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self._assert_values_equal(total, result) @test_util.run_in_graph_and_eager_modes @@ -53,7 +53,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(total, result) @@ -62,7 +62,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) n = 2 expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) self._assert_values_equal(expected, result) @test_util.run_in_graph_and_eager_modes @@ -71,7 +71,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) n = 2 expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(expected, result) @@ -79,7 +79,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): def testIsIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + self.assertTrue(cross_device_utils.contains_indexed_slices(t)) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_List(self): @@ -87,7 +87,7 @@ class IndexedSlicesUtilsTest(test.TestCase, 
parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1])) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_Tuple(self): @@ -95,7 +95,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1))) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_PerReplica(self): @@ -104,18 +104,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica)) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerReplicaMapOutput(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_replica = value_lib.PerReplica({ - "/gpu:0": value_lib.MapOutput([t0]), - "/cpu:0": value_lib.MapOutput([t1])}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica)) + self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( mode=["graph", "eager"], @@ -124,7 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): with ops.device("/cpu:0"): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) destination = "/gpu:0" - result = 
cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( t, destination) self._assert_values_equal(t, result) @@ -139,7 +128,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( t, destination) self.assertIsInstance(result, ops.IndexedSlices) diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index a1355c0b09e51c18cc4f8967dfc2c472d63593b9..e17085628ba6d1dfc79839fd824801723f07a518 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -34,7 +34,7 @@ from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -63,7 +63,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, 
distribution, use_train_and_evaluate): @@ -75,12 +77,12 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=True) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 8f82b4c92aa4305af121855972df4947c963850d..b369a7fefe6f35cf5a9b64451419cf4f72a99471 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -24,7 +24,6 @@ import json import os import sys import tempfile -import threading from absl.testing import parameterized import numpy as np @@ -45,11 +44,13 @@ from tensorflow.python.estimator import training as estimator_training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import session_manager + BATCH_SIZE = 10 LABEL_DIMENSION = 2 @@ -68,57 +69,19 @@ PS = dc._TaskType.PS original_run_std_server = dc._run_std_server -class MockOsEnv(dict): - - def 
__init__(self, *args): - self._thread_local = threading.local() - super(MockOsEnv, self).__init__(*args) - - def get(self, key, default): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.get(self._thread_local.dict, key, default) - else: - return dict.get(self, key, default) - - def __getitem__(self, key): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__getitem__(self._thread_local.dict, key) - else: - return dict.__getitem__(self, key) - - def __setitem__(self, key, val): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__setitem__(self._thread_local.dict, key, val) - else: - return dict.__setitem__(self, key, val) - - -class DistributeCoordinatorIntegrationTest(test.TestCase, - parameterized.TestCase): +class DistributeCoordinatorIntegrationTest( + multi_worker_test_base.IndependentWorkerTestBase, parameterized.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" + super(DistributeCoordinatorIntegrationTest, cls).setUpClass() cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2, has_eval=True) def setUp(self): self._model_dir = tempfile.mkdtemp() - self._mock_os_env = MockOsEnv() - self._mock_context = test.mock.patch.object(os, "environ", - self._mock_os_env) super(DistributeCoordinatorIntegrationTest, self).setUp() - self._mock_context.__enter__() - - def tearDown(self): - self._mock_context.__exit__(None, None, None) - super(DistributeCoordinatorIntegrationTest, self).tearDown() def dataset_input_fn(self, x, y, batch_size, shuffle): @@ -141,8 +104,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, def _extract_loss_and_global_step(self, event_folder): """Returns the loss and global step in last event.""" event_paths = glob.glob(os.path.join(event_folder, 
"events*")) - self.assertGreater(len(event_paths), 0, - msg="Event file not found in dir %s" % event_folder) + self.assertNotEmpty( + event_paths, msg="Event file not found in dir %s" % event_folder) loss = None global_step_count = None @@ -202,10 +165,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={"x": DATA}, y=DATA, - batch_size=BATCH_SIZE // len(train_distribute.worker_devices), + batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync, shuffle=True) if eval_distribute: - eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices) + eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync else: eval_batch_size = BATCH_SIZE eval_input_fn = self.dataset_input_fn( @@ -285,27 +248,34 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ]) self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + def _get_strategy_object(self, strategy_cls): + if strategy_cls == mirrored_strategy.CoreMirroredStrategy: + return strategy_cls(mirrored_strategy.all_local_devices()) + else: + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + @combinations.generate( combinations.combine( mode=["graph"], train_distribute_cls=[ collective_all_reduce_strategy.CollectiveAllReduceStrategy, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy ], eval_distribute_cls=[ - None, mirrored_strategy.MirroredStrategy, + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - 
num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -322,20 +292,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, mode=["graph"], train_distribute_cls=[ mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, ], required_gpus=[0, 1])) def test_estimator_standalone_client(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -355,47 +325,15 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, self._barrier.wait() return ret - def _task_thread(self, train_distribute, eval_distribute, tf_config): - os.environ["TF_CONFIG"] = json.dumps(tf_config) + def _independent_worker_fn( + self, + train_distribute, + eval_distribute, + ): with test.mock.patch.object(dc, "_run_std_server", self._mock_run_std_server): self._complete_flow(train_distribute, eval_distribute) - def _run_task_in_thread(self, cluster_spec, task_type, task_id, - train_distribute, eval_distribute): - if task_type: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - else: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - t = threading.Thread( - target=self._task_thread, - args=(train_distribute, eval_distribute, tf_config)) - t.start() - return t - - def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, - eval_distribute): - threads = {} - for task_type in cluster_spec.keys(): - threads[task_type] 
= [] - for task_id in range(len(cluster_spec[task_type])): - t = self._run_task_in_thread(cluster_spec, task_type, task_id, - train_distribute, eval_distribute) - threads[task_type].append(t) - return threads - @combinations.generate( combinations.combine( mode=["graph"], @@ -405,21 +343,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_between_graph( self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) - if (context.num_gpus() < 2 and eval_distribute_cls == collective_all_reduce_strategy.CollectiveAllReduceStrategy): self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") + train_distribute = self._get_strategy_object(train_distribute_cls) + if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -435,8 +372,9 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, # 3 workers and 1 evaluator. 
self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) for task_type, ts in threads.items(): if task_type == PS: continue @@ -449,17 +387,22 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, @combinations.generate( combinations.combine( mode=["graph"], - train_distribute_cls=[mirrored_strategy.MirroredStrategy], - eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy], + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], + eval_distribute_cls=[ + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -467,8 +410,9 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, num_workers=3, num_ps=0, has_eval=True) # 3 workers and 1 evaluator. 
self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) threads[WORKER][0].join() threads[EVALUATOR][0].join() @@ -506,7 +450,8 @@ class RunConfigTest(test.TestCase): "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) def test_should_run_distribute_coordinator(self): """Tests that should_run_distribute_coordinator return a correct value.""" @@ -529,10 +474,12 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config_with_train_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) config_with_eval_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + eval_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) self.assertTrue( dc_training.should_run_distribute_coordinator( config_with_train_distribute)) @@ -545,26 +492,27 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) self.assertFalse(dc_training.should_run_distribute_coordinator(config)) def 
test_init_run_config_duplicate_distribute(self): with self.assertRaises(ValueError): run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy())) + train_distribute=mirrored_strategy.CoreMirroredStrategy())) with self.assertRaises(ValueError): run_config_lib.RunConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(), + eval_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy())) + eval_distribute=mirrored_strategy.CoreMirroredStrategy())) def test_init_run_config_none_distribute_coordinator_mode(self): # We don't use distribute coordinator for local training. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) dc_training.init_run_config(config, {}) self.assertIsNone(config._distribute_coordinator_mode) @@ -572,7 +520,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertIsNone(config._distribute_coordinator_mode) # When `train_distribute` is not specified, don't use distribute @@ -588,7 +536,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertEqual(config._distribute_coordinator_mode, dc.CoordinatorMode.INDEPENDENT_WORKER) @@ -597,7 +545,7 @@ class RunConfigTest(test.TestCase): # `experimental.remote_cluster` 
is set use distribute coordinator with # STANDALONE_CLIENT mode. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( remote_cluster={"chief": ["fake_worker"]})) self.assertEqual(config._distribute_coordinator_mode, @@ -605,5 +553,15 @@ class RunConfigTest(test.TestCase): if __name__ == "__main__": + # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly. + orig_init = session_manager.SessionManager.__init__ + + def new_init(*args, **kwargs): + kwargs.pop("recovery_wait_secs", None) + kwargs["recovery_wait_secs"] = 0.5 + orig_init(*args, **kwargs) + + session_manager.SessionManager.__init__ = new_init + with test.mock.patch.object(sys, "exit", os._exit): test.main() diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index 0fd3acd045170c04ebdaa9c84d0cb7267a4bc68a..60fda996642464135fe1fb8c314bcf7f04d19362 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -20,6 +20,10 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.keras.optimizer_v2 import rmsprop + + NUM_CLASSES = 10 @@ -102,18 +106,23 @@ def main(_): # Build the train and eval datasets from the MNIST data. Also return the # input shape which is constructed based on the `image_data_format` # i.e channels_first or channels_last. + tf.enable_eager_execution() + train_ds, eval_ds, input_shape = get_input_datasets() model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. 
- strategy = tf.contrib.distribute.MirroredStrategy() + # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. + strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) + + optimizer = rmsprop.RMSProp(learning_rate=0.001) # Compile the model by passing the distribution strategy object to the # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed # based on the strategy instantiated. model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001), + optimizer=optimizer, metrics=['accuracy'], distribute=strategy) diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 46a1cf41c55b371e87979ca625765e0531ac188b..6dfd85bcc4f3784e2744fd876a7190cc9581d96a 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -25,18 +25,23 @@ import numpy as np import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.estimator import run_config from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column +from tensorflow.python.framework import constant_op +from 
tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -64,7 +69,9 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): @@ -76,11 +83,11 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices)) + batch_size=batch_size // distribution.num_replicas_in_sync) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices)) + batch_size=batch_size // distribution.num_replicas_in_sync) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) @@ -136,44 +143,51 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): shutil.rmtree(self._model_dir) -class MirroredStrategyOptimizerV2Test(test.TestCase): +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model - def testKerasOptimizerWithUnequalInput(self): - if context.num_gpus() < 1: - self.skipTest('Not enough GPUs.') - def create_fn(device_id): +class 
MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testKerasOptimizerWithUnequalInput(self, distribution): + def create_fn(): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. - loss = (device_id + 1) * var + loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) train_op = optimizer.minimize(loss, var_list=[var]) m = optimizer.get_slot(var, 'm') v = optimizer.get_slot(var, 'v') - return (var, m, v, train_op, optimizer.iteration) + return (var, m, v, train_op, optimizer.iterations) devices = ['/device:GPU:0', '/device:CPU:0'] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - (var, m, v, op, counter) = dist.call_for_each_replica( - create_fn, args=[dist.worker_device_index]) + with distribution.scope(): + (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) self.evaluate(variables.global_variables_initializer()) var_val = [2.0, 2.0, 2.0] self.assertAllClose( var_val, self.evaluate( - [dist.read_var(var), + [distribution.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([0, 0, 0], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) - train_op = dist.unwrap(op) + train_op = distribution.unwrap(op) self.evaluate(train_op) # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 m_val = [1.2, 1.2, 1.2] @@ -181,7 +195,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( m_val, self.evaluate( - [dist.read_var(m), + [distribution.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(1) = 
beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 @@ -189,7 +203,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( v_val, self.evaluate( - [dist.read_var(v), + [distribution.read_var(v), v.get(devices[0]), v.get(devices[1])])) # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) @@ -198,12 +212,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( var_val, self.evaluate( - [dist.read_var(var), + [distribution.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([1, 1, 1], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) @@ -214,7 +228,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( m_val, self.evaluate( - [dist.read_var(m), + [distribution.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 @@ -222,16 +236,50 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( v_val, self.evaluate( - [dist.read_var(v), + [distribution.read_var(v), v.get(devices[0]), v.get(devices[1])])) self.assertAllClose([2, 2, 2], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): + + with self.cached_session(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit( + inputs, + targets, + epochs=1, + batch_size=2, + 
verbose=0, + validation_data=(inputs, targets)) + model.evaluate(inputs, targets) + model.predict(inputs) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 0db5844e4c40e84c635b063523b95226241d07fb..683cc89bfbae9c877ea6794d311ffc00c96c6937 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -24,9 +24,10 @@ import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import constant_op @@ -35,14 +36,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile -from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop - _RANDOM_SEED = 1337 
_TRAIN_SIZE = 200 _INPUT_SIZE = (10,) @@ -165,7 +165,9 @@ def get_multi_inputs_multi_outputs_data(): return (train_data, test_data) -def batch_wrapper(dataset, batch_size, distribution): +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. if isinstance(distribution, tpu_strategy.TPUStrategy): @@ -212,13 +214,19 @@ def multi_input_output_model(): return model -def get_correctness_test_inputs(use_numpy, with_distribution, +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict): """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 global_batch_size = 64 batch_size = global_batch_size # TODO(b/118776054): Use global batch size for Keras/DS support. - if with_distribution: + use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: batch_size //= with_distribution.num_replicas_in_sync if use_numpy: @@ -226,19 +234,20 @@ def get_correctness_test_inputs(use_numpy, with_distribution, 'batch_size': batch_size, 'x': x_train, 'y': y_train, - 'epochs': 1, + 'epochs': training_epochs, 'shuffle': False, } - eval_inputs = { - 'batch_size': batch_size, - 'x': x_train, - 'y': y_train, - } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } predict_inputs = { - # TODO(b/119318587): We should not require batch_size when distribution - # is enabled. 
- 'batch_size': (len(x_predict) // with_distribution.num_replicas_in_sync - if with_distribution else None), 'x': np.array(x_predict, dtype=np.float32), } else: @@ -246,30 +255,39 @@ def get_correctness_test_inputs(use_numpy, with_distribution, # keras.fit/evaluate/predict. The batch size is part of the dataset. train_dataset = dataset_ops.Dataset.from_tensor_slices( (x_train, y_train)) - x = batch_wrapper(train_dataset, batch_size, with_distribution) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) training_inputs = { 'batch_size': None, 'x': x, 'y': None, - 'epochs': 1, + 'epochs': training_epochs, 'shuffle': False, 'steps_per_epoch': len(x_train) // global_batch_size, } - eval_inputs = { - 'batch_size': None, - 'x': x, - 'y': None, - 'steps': 20, - } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) - if with_distribution: + if use_per_core_batch_size: predict_batch_size //= with_distribution.num_replicas_in_sync predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) predict_dataset = batch_wrapper(predict_dataset, predict_batch_size, with_distribution) predict_inputs = { - 'batch_size': None, 'steps': 1, 'x': predict_dataset, } @@ -277,47 +295,71 @@ def get_correctness_test_inputs(use_numpy, with_distribution, return training_inputs, eval_inputs, predict_inputs -strategies = [combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.tpu_strategy, # steps_per_run=2 - combinations.tpu_strategy_one_step] 
+strategies_minus_tpu = [ + combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] def strategy_minus_tpu_combinations(): return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - mode=['graph']) + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) -def strategy_combinations(): +def tpu_strategy_combinations(): return combinations.combine( - distribution=strategies, + distribution=tpu_strategies, mode=['graph']) -def strategy_and_optimizer_combinations(): - return combinations.combine( - distribution=strategies, - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn], - mode=['graph']) +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() -def strategy_and_inputs(): +# TODO(priyag): Add v2 optimizers here. 
+def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine( + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn])) + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +def strategy_for_numpy_input_combinations(): return combinations.combine( - distribution=strategies, - use_numpy=[True, False], + distribution=strategies_minus_tpu + tpu_strategies, mode=['graph']) -class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -325,17 +367,18 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.MakeDirs(self._base_dir) self._config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) - self._dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) def tearDown(self): writer_cache.FileWriterCache.clear() if os.path.isdir(self._base_dir): gfile.DeleteRecursively(self._base_dir) - def test_train_functional_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + 
combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_functional_with_distribution_strategy(self, distribution): keras_model = simple_functional_model() keras_model.compile( loss='categorical_crossentropy', @@ -343,8 +386,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist, - eval_distribute=dist) + train_distribute=distribution, + eval_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -358,9 +401,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_train_sequential_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_sequential_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -368,7 +414,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -382,7 +428,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def 
test_multi_inputs_multi_outputs_with_input_fn_as_dict(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() def train_input_fn(): @@ -412,14 +463,14 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): output_dict)).batch(16) self.do_test_multi_inputs_multi_outputs_with_input_fn( - train_input_fn, eval_input_fn) + distribution, train_input_fn, eval_input_fn) - def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn, - eval_input_fn): + def do_test_multi_inputs_multi_outputs_with_input_fn( + self, distribution, train_input_fn, eval_input_fn): config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=self._dist) + train_distribute=distribution) with self.cached_session(): model = multi_inputs_multi_outputs_model() est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) @@ -429,9 +480,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) self.assertLess(eval_results['loss'], baseline_eval_results['loss']) - def test_keras_optimizer_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -439,7 +493,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): config 
= run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator(keras_model=keras_model, config=config) @@ -455,7 +509,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_creating_var_with_numpy_arrays(self, distribution): with self.cached_session(): x = np.asarray(np.random.random((64, 3)), dtype=np.float32) @@ -464,84 +518,135 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # Verify that the numpy value is copied to the variable. self.assertAllEqual(x, val) - def test_calculating_batch_params(self): - # This verifies that we calculate the number of steps when the batch size - # is specified. + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_no_batch_size(self, distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + with self.cached_session(): - # 64 is the number of input samples. - inputs = np.zeros((64, 3), dtype=np.float32) - # The number of replicas is equal to 3. - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0', - '/device:GPU:1']) - - with self.assertRaisesRegexp(ValueError, 'Please specify a batch_size ' - 'that is smaller than'): - # The batch size(128) is larger than the number of input - # samples(64). 
- distributed_training_utils.get_input_batch_params(inputs, - 128, - strategy) - - with self.assertRaisesRegexp(ValueError, 'is smaller than the number ' - 'of replicas'): - # The batch size(32) * num_replicas_in_sync(3) is 96 which is greater - # than the number of input samples(64). - distributed_training_utils.get_input_batch_params(inputs, - 32, - strategy) - - # The number of replicas now is equal to 2. - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - # 32 is the batch size per replica. - steps = distributed_training_utils.get_input_batch_params(inputs, - 32, - strategy) - # The number of batches is the ratio of input samples(64) to - # batch size(32) which is 2. The number of steps(1) is the ratio of - # number of batches(2) to the number of replicas(2). + # Input samples of different sizes + input_20_samples = np.zeros((20, 3), dtype=np.float32) + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Default global batch size 32 for input with 64 samples run in 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) + + # Computed global batch size 20 is lower than 32 if we pass less samples. + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_20_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 20 // replica_scale_factor) self.assertEqual(steps, 1) - # 16 is the batch size per replica. - steps = distributed_training_utils.get_input_batch_params(inputs, - 16, - strategy) - # The number of batches is the ratio of input samples(64) to - # batch size(16) which is 4. The number of steps(2) is the ratio of - # number of batches(4) to the number of replicas(2). + # Default global batch size 32 cannot be used with 63 samples. 
+ with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=None, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_no_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + # Input samples of different sizes + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed global batch size is correct for number of specified 1 step + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=1, batch_size=None) + self.assertEqual(batch_size, 64 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Computed global batch size is correct for number of specified 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=2, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) self.assertEqual(steps, 2) - def test_calculating_batch_size(self): + # All samples can not be consumed in specified number of steps + with self.assertRaisesRegexp(ValueError, 'not divisible by steps'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=2, batch_size=None) + + # This cases is different for different strategies due to the + # difference in supported batch size being global or per-replica. 
+ if replica_scale_factor == 1: + # Computed global batch size is correct even if not sharadable + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=3, batch_size=None) + self.assertEqual(batch_size, 21) + self.assertEqual(steps, 3) + else: + # Computed global batch size can not be sharded across replicas + with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly ' + 'across the sync replicas'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=1, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_with_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + with self.cached_session(): - # 64 is the number of input samples. 
- inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=16) + self.assertEqual(batch_size, 16) + self.assertEqual(steps, 4 // replica_scale_factor) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=32) + self.assertEqual(batch_size, 32) + self.assertEqual(steps, 2 // replica_scale_factor) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=20) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=3) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_with_batch_size(self, + distribution): + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - strategy._require_static_shapes = True - - model.compile(optimizer, loss, distribute=strategy) - iterator = model._distribution_standardize_user_data(inputs, - targets, - batch_size=None, - check_steps=True, - steps_name='steps', - steps=3) - - # The global batch size(21) across all replicas is the ratio of the input - # samples(64) to the steps(3). 
- # The batch size(10) per device is the ratio of the global batch size(21) - # to the number of replicas(2). - # The global batch size and batch size are rounded integer values. - self.assertEqual(10, distributed_training_utils.get_batch_dimension( - iterator._iterator)) - - @combinations.generate(strategy_combinations()) + # No change to steps and batch size if both specified and feasible + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=5, batch_size=3) + self.assertEqual(batch_size, 3) + self.assertEqual(steps, 5) + + # Number of samples is less than global batch size * steps + with self.assertRaisesRegexp(ValueError, 'less than samples required'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=10, batch_size=13) + + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): model = get_model() @@ -572,7 +677,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -606,21 +711,22 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) def test_numpy_with_sample_weights(self, distribution): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' model.compile(optimizer, loss, distribute=distribution) - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - sample_weights = 
np.ones((10), np.float32) + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): model = multi_input_output_model() @@ -638,7 +744,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # `predict` a list that is equal in length to the number of model outputs. # In this test our model has two outputs and each element of `outs` # corresponds to all the samples of one of the model outputs. - self.assertEqual(2, len(outs)) + self.assertLen(outs, 2) # Each of the output samples have a dimension of 7. We should process all # the available input samples(6). self.assertAllEqual([6, 7], outs[0].shape) @@ -648,7 +754,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -667,7 +773,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, validation_data=dataset, validation_steps=2) model.predict(get_predict_dataset(distribution), steps=2) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): user_controlled_model = get_model() @@ -710,16 +816,20 @@ class TestDistributionStrategyWithDatasets(test.TestCase, # TODO(priyag): Enable this test for TPU. 
Currently tuples/dict don't work # as clone_model's input_tensors argument only seems to accept list and not # tuples or dict. - def test_fit_with_tuple_and_dict_dataset_inputs(self): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): model = multi_input_output_model() optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -743,7 +853,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -792,35 +902,48 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.evaluate(dataset, steps=2, verbose=1) model.predict(dataset, steps=2) - def test_dataset_input_shape_validation(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) 
- - model.compile(optimizer, loss, distribute=strategy) + model.compile(optimizer, loss, distribute=distribution) - # User forgets to batch the dataset - inputs = np.zeros((10, 3), dtype=np.float32) + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) + dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - # Wrong input shape - inputs = np.zeros((10, 5), dtype=np.float32) + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_dataset_no_batch_input_validation(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, - 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( @@ -842,7 +965,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - def test_learning_phase_value(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + 
combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. @@ -856,15 +984,17 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.005) loss = 'mse' metrics = ['acc'] - strategy = mirrored_strategy.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + batch_size = 8 + if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): + # CoreMirroredStrategy uses global batch size. + batch_size = 8 * distribution.num_replicas_in_sync inputs = np.ones((10, 1), dtype=np.float32) targets = np.ones((10, 1), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat().batch(8) + dataset = dataset.repeat().batch(batch_size) hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) self.assertAlmostEqual(hist.history['acc'][0], 0, 0) @@ -875,24 +1005,51 @@ class TestDistributionStrategyWithDatasets(test.TestCase, inputs = np.ones((10, 1), dtype=np.float32) predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - predict_dataset = predict_dataset.repeat().batch(5) + + predict_dataset = predict_dataset.repeat().batch(batch_size) output = model.predict(predict_dataset, steps=10) - # `predict` runs for 10 steps and in each step you process 100 samples. 
- ref_output = np.ones((100, 1), dtype=np.float32) + # `predict` runs for 10 steps + ref_output = np.ones((160, 1), dtype=np.float32) self.assertArrayNear(output, ref_output, 1e-1) + @combinations.generate(strategy_minus_tpu_combinations()) + def testOptimizerWithCallbacks(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) + + def schedule(_): + return 0.001 + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + grouped_models = distribution.unwrap(model._grouped_model) + with distribution.scope(): + for m in grouped_models: + self.assertAllClose(0.001, keras.backend.get_value( + m.optimizer.lr), atol=1e-05, rtol=1e-05) + class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - def test_validating_dataset_input_tensors_with_shape_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2)) b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor shape details from the error message # since the order of the device and the corresponding input tensor shape # is not deterministic over different runs. 
@@ -901,17 +1058,21 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor dtype details from the error message # since the order of the device and the corresponding input tensor dtype # is not deterministic over different runs. 
@@ -920,21 +1081,23 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_unsupported_features(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) # Test with validation split with self.assertRaisesRegexp( @@ -969,45 +1132,48 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'you should specify the `steps` argument'): model.predict(dataset, verbose=0) - def test_calling_with_unsupported_predefined_callbacks(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset 
= get_dataset(distribution) def schedule(_): return 0.001 with self.assertRaisesRegexp(ValueError, - 'LearningRateScheduler callback is not ' - 'supported with DistributionStrategy.'): + 'You must specify a Keras Optimizer V2 when ' + 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) with self.assertRaisesRegexp(ValueError, - 'ReduceLROnPlateau callback is not ' - 'supported with DistributionStrategy.'): + 'You must specify a Keras Optimizer V2 when ' + 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) - with self.assertRaisesRegexp(ValueError, - 'histogram_freq in the TensorBoard callback ' - 'is not supported when using ' - 'DistributionStrategy.'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) -class TestDistributionStrategyWithLossMasking(test.TestCase): +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. 
- def test_masking(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) @@ -1016,12 +1182,9 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): model.add( keras.layers.TimeDistributed( keras.layers.Dense(1, kernel_initializer='one'))) - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -1033,7 +1196,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() @@ -1065,7 +1228,7 @@ class TestDistributionStrategyWithNormalizationLayer( class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_metric_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') @@ -1088,21 +1251,63 @@ class TestDistributionStrategyCorrectness(test.TestCase, distribute=distribution) batch_size = 64 - batch_size //= distribution.num_replicas_in_sync + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync train_dataset = 
dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0]) + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) - @combinations.generate(strategy_and_inputs()) - def test_correctness(self, distribution, use_numpy): + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): with self.cached_session(): - tolerance = 1e-5 + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + distribute=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): - if isinstance(distribution, mirrored_strategy.MirroredStrategy): - # TODO(b/119257215): use the default one once the flakyness is fixed. 
- tolerance = 1e-4 + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} + + if isinstance(distribution, (mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy)): + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } keras.backend.set_image_data_format('channels_last') np.random.seed(_RANDOM_SEED) @@ -1123,49 +1328,75 @@ class TestDistributionStrategyCorrectness(test.TestCase, # This is used to initialize the model for both the distribution and # non-distribution run. In addition, we add few non-linear layers to make # it non-trivial. - model = keras.Sequential() - model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(10, activation='relu')) - model.add(keras.layers.Dense(1)) + def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() initial_weights = model.get_weights() + del model # avoid accident usage. - def fit_and_predict(with_distribution=None): + def fit_eval_and_predict(with_distribution=None): + model = _create_model() # We have initialized the model to the same weight for the distribution # and non-distribution run. 
model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], distribute=with_distribution) training_inputs, eval_inputs, predict_inputs = ( - get_correctness_test_inputs(use_numpy, with_distribution, + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, x_train, y_train, x_predict)) - model.fit(**training_inputs) - eval_result = model.evaluate(**eval_inputs) - weights = model.get_weights() - predict_result = model.predict(**predict_inputs) + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. + + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() - return weights, eval_result, predict_result + return result - wts_with_ds, eval_with_ds, predict_with_ds = fit_and_predict( - with_distribution=distribution) - wts_without_ds, eval_without_ds, predict_without_ds = fit_and_predict( - with_distribution=None) + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) - # Verify that the weights, eval results, predict outputs are the same - # within some limits of tolerance. 
- self.assertAllClose( - wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) - self.assertAllClose( - predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. + continue + tolerance = tol_table.get(key, default_tolerance) -# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index c28ab416518799e239bff43def75e00b7c22ee73..8ac659abe96370b751ed1556cc699fe20788a0fd 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -72,14 +72,14 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() -# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using -# ReplicaLocalVariables on TPUs. Submit http://cl/208914352. 
def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph"]) @@ -100,18 +100,19 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): value, update = distribution.call_for_each_replica( - metric_fn, args=[inputs]) + metric_fn, args=inputs) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) ctx = distribution.run_steps_on_dataset( - step_fn, iterator, iterations=distribution.steps_per_run) + step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = ctx.non_tensor_outputs["value"] # In each run, we run multiple steps, and each steps consumes as many # batches as number of replicas. 
batches_per_update = ( - distribution.num_replicas_in_sync * distribution.steps_per_run) + distribution.num_replicas_in_sync * + distribution.extended.steps_per_run) else: value, update = distribution.call_for_each_replica( metric_fn, iterator.get_next()) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index c6562463edbf8e03d5771a5147dc227ddf438c40..f09483cb56b66fd4720ee71085203c14f1ccadc3 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -63,7 +64,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( distribution.call_for_each_replica(model_fn, args=inputs)) @@ -157,7 +158,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): use_callable_loss=True, create_optimizer_inside_model_fn=True) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( distribution.call_for_each_replica(model_fn, args=inputs)) @@ -226,12 +227,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm=renorm, update_ops_in_replica_mode=not update_ops_in_cross_replica_mode) - def step_fn(ctx, *inputs): + def 
step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( distribution.call_for_each_replica(model_fn, args=inputs)) if update_ops_in_cross_replica_mode: - fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -285,7 +286,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ]), combinations.combine( mode=["graph"], use_callable_loss=[True, False]) + @@ -321,10 +324,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) return dataset_ops.Dataset.zip((features, labels)).repeat() - def step_fn(ctx, x, y): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica(model_fn, args=(x, y))) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -341,7 +344,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): run_step() v = all_vars[0] - self.assertTrue(all([v is vi for vi in all_vars[1:]])) + self.assertTrue(all(v is vi for vi in all_vars[1:])) weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w @@ -402,21 +405,21 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): train_op = optimizer.minimize(loss_fn) loss = loss_fn() output_context.set_last_step_output( - name="replica_loss_agg", + name="replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + 
reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_non_tensor_output(key1, value1) return (train_op, loss) - def step_fn(output_context, *inputs): + def step_fn(output_context, inputs): (train_op, loss) = distribution.call_for_each_replica( model_fn, args=(output_context,) + inputs) output_context.set_last_step_output( - name="cross_replica_loss_agg", + name="cross_replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_last_step_output( - name="cross_replica_loss_noagg", + name="cross_replica_loss_not_reduced", output=loss) return distribution.group(train_op) @@ -424,36 +427,36 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def run_step(): initial_loss = lambda: constant_op.constant(1e7) - # Initial values corresponding to aggregated losses are just single - # tensors. But for non aggregated losses, we need to have initial + # Initial values corresponding to reduced losses are just single + # tensors. But for non reduced losses, we need to have initial # values that are of the same structure as non reduced losses. In # MirroredStrategy, this will be a list of losses, in TPUStrategy # it will be single tensor. Using `broadcast` followed by `unwrap` # gives us the desired initial value structure. 
initial_loop_values = { - "replica_loss_agg": initial_loss(), - "cross_replica_loss_agg": initial_loss(), - "cross_replica_loss_noagg": + "replica_loss_reduced": initial_loss(), + "cross_replica_loss_reduced": initial_loss(), + "cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } ctx = distribution.run_steps_on_dataset( step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) - self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["replica_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_replica_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"], - aggregated=False, distribution=distribution) - return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"]) + loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"], + reduced=False, distribution=distribution) + return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) self.evaluate(distribution.initialize()) if not context.executing_eagerly(): @@ -478,18 +481,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(error_is_not_increasing) - def _verify_loss_output(self, initial_loss, loss_output, aggregated, + def _verify_loss_output(self, initial_loss, loss_output, reduced, distribution): - if not aggregated: - self.assertEqual(distribution.num_replicas_in_sync, - 
len(distribution.unwrap(loss_output))) - loss_output = distribution.reduce( - aggregation=variables_lib.VariableAggregation.MEAN, - value=loss_output, destinations="/device:CPU:0") - - unwrapped_output = distribution.unwrap(loss_output) - self.assertEqual(1, len(unwrapped_output)) - loss_tensor = unwrapped_output[0] + if not reduced: + self.assertLen(distribution.unwrap(loss_output), + distribution.num_replicas_in_sync) + loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output) + else: + unwrapped_output = distribution.unwrap(loss_output) + self.assertLen(unwrapped_output, 1) + loss_tensor = unwrapped_output[0] self.assertEqual(initial_loss.dtype, loss_tensor.dtype) self.assertEqual(initial_loss.shape, loss_tensor.shape) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index e90c510aadb40555cacf60bcff5516e87e06b728..20f1a08d4261b931a9353738147fba7d7dff9225 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -12,293 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Class MirroredStrategy implementing DistributionStrategy.""" +"""Contrib version of MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -from functools import partial -import threading +import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import shared_variable_creator -from tensorflow.contrib.distribute.python import values -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import coordinator -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.util import nest +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import values -# TODO(josh11b): Replace asserts in this file with if ...: raise ... 
- - -@contextlib.contextmanager -def _enter_graph(g): - if context.executing_eagerly(): - with g.as_default(), context.eager_mode(): - yield - else: - with g.as_default(): - yield - - -def _cpu_device(device): - cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) - return cpu_device.to_string() - - -class _RequestedStop(Exception): - pass - - -# _call_for_each_replica and _reduce_non_distributed_value are not members of -# MirroredStrategy so that they are generally not allowed to use anything -# specific to MirroredStrategy and thus can be shared with other distribution -# strategies. - - -# TODO(yuefengz): maybe create a common class for those who need to call this -# _call_for_each_replica. -def _call_for_each_replica(distribution, fn, args, kwargs): - """Run `fn` in separate threads, once per replica/worker device. - - Args: - distribution: the DistributionStrategy object. - fn: function to run (will be run once per device, each in its own thread). - args: positional arguments for `fn` - kwargs: keyword arguments for `fn`. - - Returns: - Merged return value of `fn` across all replicas. - - Raises: - RuntimeError: If fn() calls get_replica_context().merge_call() a different - number of times from the available devices. - """ - # TODO(josh11b): Add this option once we add synchronization to variable - # creation. Until then, this is pretty unsafe to use. - run_concurrently = False - if not context.executing_eagerly(): - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - - coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. 
- threads = [] - for index, d in enumerate(distribution.worker_devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredReplicaThread( # pylint: disable=protected-access - distribution, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredReplicaThread - # (`MRT`) threads. The execution waits until - # `MRT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_replica_context().merge_call()` is called. If `fn` is - # complete, then `MRT.done` is set to True. Otherwise, arguments - # of `get_replica_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_replica_context().merge_call` are then set to `MRT.merge_result`. - # Each such `get_replica_context().merge_call` call returns the - # `MRT.merge_result` for that thread when `MRT.should_run` event - # is reset again. Execution of `fn` resumes. 
- - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some replicas made a different number of " - "replica_context().merge_call() calls.") - # get_replica_context().merge_call() case - merge_args = values.regroup({t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MRT when we call merge_fn - # to ensure that if we have opened a name scope in the MRT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MRT and assume it is - # the same for all other MRTs. 
- mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn(distribution, *merge_args, - **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) - - -def _reduce_non_distributed_value(distribution, aggregation, value, - destinations): - """Reduce a non-DistributedValue `value` to `destinations`.""" - if isinstance(value, values.DistributedValues): - raise ValueError("You are passing a `DistributedValue` to " - "`_reduce_non_distributed_value`, which is not allowed.") - - # If the same value is present on all replicas then the PerReplica value will - # be a single value. We also handle the case when `value` is a single value - # and equal to 0. - if value == 0: - return 0 - # If the aggregation type is MEAN or ONLY_FIRST_REPLICA, then this - # essentially means that the same value should be on all destinations. - if aggregation in ( - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA): - return value - - cross_tower_ops_lib.validate_destinations(destinations) - # We do not support an aggregation type of SUM if the value is the same across - # all replicas. We call this as part of assign functions for MirroredVariables - # and summing up identical values across replicas is not clearly defined. - if (len(distribution.worker_devices) != 1 or - not cross_tower_ops_lib.check_destinations(destinations)): - raise ValueError("A non-DistributedValues value %s cannot be reduced with " - "the given aggregation %s." % (value, aggregation)) - # TODO(anjalisridhar): Moves these methods to a device utility file? 
- devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - - -def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Get synchronization value - synchronization = kwargs.get("synchronization", - variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are replica local. - is_replica_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. 
- is_replica_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in ( - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - - if is_replica_local: - result = values.ReplicaLocalVariable( - index, index[devices[0]], aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. 
- if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - if v in l: - l.remove(v) - g.add_to_collections(collections, result) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) - - return result +# pylint: disable=protected-access,invalid-name +_call_for_each_replica = mirrored_strategy._call_for_each_replica +_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value +_create_mirrored_variable = mirrored_strategy._create_mirrored_variable +all_local_devices = mirrored_strategy.all_local_devices +CoreMirroredStrategy = mirrored_strategy.MirroredStrategy +CoreMirroredExtended = mirrored_strategy.MirroredExtended +# pylint: enable=protected-access,invalid-name class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices and machines. + *** contrib version *** + This strategy uses one replica per device and sync replication for its multi-GPU version. @@ -353,483 +95,66 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_device_ops=None, auto_shard_dataset=False, cross_tower_ops=None): - super(MirroredStrategy, self).__init__() - assert not (cross_device_ops and cross_tower_ops) - self._cross_tower_ops = cross_device_ops or cross_tower_ops - self._auto_shard_dataset = auto_shard_dataset - # Remember num GPUs which might be needed by `configure` method. 
if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( "You cannot specify both `num_gpus` and `num_gpus_per_worker`.") - if num_gpus is not None: - self._num_gpus = num_gpus - else: - self._num_gpus = num_gpus_per_worker - - self._initialize_local(self._num_gpus, devices) - - def _initialize_local(self, num_gpus, devices): - """Initializes the object for local training.""" - self._cluster_spec = None - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: - if num_gpus is None: - num_gpus = context.num_gpus() - if num_gpus == 0: - devices = ["/device:CPU:0"] - else: - devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") - self._num_gpus = num_gpus - # TODO(yuefengz): consider setting the default device. - - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? 
- self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerReplica( - {d: i for i, d in enumerate(devices)}) - - def _initialize_multi_worker(self, num_gpus, cluster_spec): - """Initializes the object for multi-worker training.""" - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._cluster_spec = cluster_spec - - self._workers = [] - for job in ["chief", "worker"]: - for task in range(len(cluster_spec.as_dict().get(job, []))): - self._workers.append("/job:%s/task:%d" % (job, task)) - if num_gpus is None: - raise ValueError("`num_gpus` is required if `cluster_spec` is given.") - if num_gpus > 0: - self._worker_devices = [ - (worker, [ - device_util.canonicalize(worker + "/device:GPU:%d" % gpu) - for gpu in range(num_gpus) - ]) for worker in self._workers - ] - else: - self._worker_devices = [ - (worker, [device_util.canonicalize(worker, "/device:CPU:0")]) - for worker in self._workers - ] - - devices = nest.flatten([l for _, l in self._worker_devices]) + num_gpus = num_gpus_per_worker + extended = MirroredExtended(self, devices, num_gpus, + cross_device_ops or cross_tower_ops, + auto_shard_dataset) + super(MirroredStrategy, self).__init__(extended) - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode. - self._default_device = self._workers[0] - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? 
- self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerReplica( - {d: i for i, d in enumerate(devices)}) +class MirroredExtended(CoreMirroredExtended): + """Implementation of (contrib) MirroredStrategy.""" - def _create_variable(self, next_creator, *args, **kwargs): - """Create a mirrored variable. See `DistributionStrategy.scope`.""" - colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) - - def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # Initialize replicas with the same value: - def initial_value_fn(device=d): - if context.executing_eagerly(): - init_value = index[devices[0]].value() - return array_ops.identity(init_value) - else: - with ops.device(device): - init_value = index[devices[0]].initial_value - return array_ops.identity(init_value) - kwargs["initial_value"] = initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - # Don't record operations (e.g. other variable reads) during - # variable creation. 
- with tape.stop_recording(): - v = next_creator(*args, **kwargs) - assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index - - return _create_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) + def __init__(self, + container_strategy, + devices=None, + num_gpus_per_worker=None, + cross_device_ops=None, + auto_shard_dataset=False): + if devices is None: + devices = mirrored_strategy.all_local_devices(num_gpus_per_worker) + elif num_gpus_per_worker is not None: + raise ValueError( + "Must only specify one of `devices` and `num_gpus_per_worker`.") + super(MirroredExtended, self).__init__(container_strategy, devices, + cross_device_ops) + self._auto_shard_dataset = auto_shard_dataset - def distribute_dataset(self, dataset_fn): - if self._cluster_spec: - return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_devices, - auto_shard=self._auto_shard_dataset) + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch. + + This implementation is different than the one in + `tf.distribute.MirroredStrategy` for purposes of backward compatibility. + We treat the incoming dataset's batch size as per replica batch size. + + Args: + dataset: `tf.data.Dataset` for input. + Returns: + An `InputIterator` which returns inputs for each step of the computation. + """ + if self._local_mode: + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, self._devices)] else: + worker_device_pairs = self._worker_devices + return values.DatasetIterator(dataset, worker_device_pairs) + + def _distribute_dataset(self, dataset_fn): + if self._local_mode: return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices) - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
- def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = values.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) - for (name, output) in ctx.last_step_outputs.items(): - # Convert all outputs to tensors, potentially from `DistributedValues`. - ctx.last_step_outputs[name] = self.unwrap(output) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. 
- last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, wrap them in a Mirrored - # container, else in a PerReplica container. - if aggregation is variables_lib.VariableAggregation.NONE: - last_step_tensor_outputs_dict[name] = values.regroup( - {d: t for d, t in zip(self._devices, output)}, values.PerReplica) - else: - assert len(output) == 1 - last_step_tensor_outputs_dict[name] = output[0] - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _broadcast(self, tensor, destinations): - # TODO(josh11b): In eager mode, use one thread per device, or async mode. - return self._get_cross_tower_ops().broadcast(tensor, destinations or - self._devices) - - def _call_for_each_replica(self, fn, args, kwargs): - return _call_for_each_replica(self, fn, args, kwargs) - - def map(self, map_over, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - index = {} - for i, m in enumerate(map_over): - d = self._devices[i % len(self._devices)] - with ops.device(d): - l = index.get(d, []) - l.append(fn(m, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs))) - index[d] = l - # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput - # in addition to PerReplica data. 
- return values.PerReplica({k: values.MapOutput(v) for k, v in index.items()}) - - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - del task_type, task_id - - if session_config: - session_config.isolate_session_state = True - - if cluster_spec: - self._initialize_multi_worker(self._num_gpus, cluster_spec) - - if self._cross_tower_ops is None: - if self._cluster_spec: - # It currently cannot detect the toplogy of remote workers. So we - # hard-code the multi-worker all-reduce algorithm for now. - if len(self._workers) == 1: - # The default is "nccl". - self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossDeviceOps() - else: - # The default is hierarchical reduce and broadcast. - self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( - self._workers, self._num_gpus) - else: - self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( - self._devices, session_config=session_config) - - def _get_cross_tower_ops(self): - if self._cross_tower_ops is None: - self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()) - return self._cross_tower_ops - - def _reduce(self, aggregation, value, destinations): - assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerReplica or - # Mirrored values. For example, the same value could be present on all - # replicas in which case `value` would be a single value or value could - # be 0. 
- return _reduce_non_distributed_value(self, aggregation, value, - destinations) - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: - value = value.get(self._devices[0]) - if isinstance(value, (int, float)): - return value - return self.broadcast(value, destinations) - return self._get_cross_tower_ops().reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: - return [self.broadcast(v.get(self._devices[0]), d) - for v, d in value_destination_pairs] - return self._get_cross_tower_ops().batch_reduce(aggregation, - value_destination_pairs) - - def _update(self, var, options, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - assert isinstance(var, values.DistributedVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - assert isinstance(colocate_with, list) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - # TODO(josh11b): In eager mode, use one thread per device. 
- updates = {} - for d in colocate_with: - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - updates[d] = fn(*values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def read_var(self, replica_local_var): - """Read the aggregate value of a replica-local variable.""" - if isinstance(replica_local_var, values.ReplicaLocalVariable): - return replica_local_var._get_cross_replica() # pylint: disable=protected-access - assert isinstance(replica_local_var, values.Mirrored) - return array_ops.identity(replica_local_var.get()) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. - if set(val.devices) == self._canonical_device_set: - return [val.get(device=d) for d in self._devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] - - def value_container(self, val): - return values.value_container(val) - - @property - def num_replicas(self): - return len(self._devices) - - @property - def num_replicas_in_sync(self): - return len(self._devices) - - def _worker_device_index(self): - return self._device_index - - @property - def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. 
- return list(self._devices) - - @property - def parameter_devices(self): - return list(self._devices) - - @property - def between_graph(self): - return False - - @property - def should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - def non_slot_devices(self, var_list): - del var_list - return list(self._devices) - - def _get_devices_from(self, colocate_with=None): - if colocate_with is None: - return self._devices else: - return cross_tower_ops_lib.get_devices_from(colocate_with) - - class _MirroredReplicaThread(threading.Thread): - """A thread that runs() a function on a device.""" - - def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, - **kwargs): - super(MirroredStrategy._MirroredReplicaThread, self).__init__() # pylint: disable=protected-access - self.coord = coord - self.distribution = dist - self.device = device - self.replica_id = dist.worker_devices.index(device) - self.variable_creator_fn = variable_creator_fn - # State needed to run and return the results of `fn`. - self.main_fn = fn - self.main_args = args - self.main_kwargs = kwargs - self.main_result = None - self.done = False - # State needed to run the next merge_call() (if any) requested via - # ReplicaContext. - self.merge_fn = None - self.merge_args = None - self.merge_kwargs = None - self.merge_result = None - self.captured_name_scope = None - # We use a thread.Event for the main thread to signal when this - # thread should start running (`should_run`), and another for - # this thread to transfer control back to the main thread - # (`has_paused`, either when it gets to a - # `get_replica_context().merge_call` or when `fn` returns). In - # either case the event starts cleared, is signaled by calling - # set(). The receiving thread waits for the signal by calling - # wait() and then immediately clearing the event using clear(). 
- self.should_run = threading.Event() - self.has_paused = threading.Event() - # These fields have to do with inheriting various contexts from the - # parent thread: - # pylint: disable=protected-access - self.context_mode = context.context()._eager_context.mode - if not context.context()._context_handle: - context.context()._initialize_handle_and_devices() - self.context_device_policy = ( - pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( - context.context()._context_handle)) - self.graph = ops.get_default_graph() - self._variable_creator_stack = self.graph._variable_creator_stack[:] - self._captured_var_scope = variable_scope.get_variable_scope() - # Adding a "/" at end lets us re-enter this scope later. - self._name_scope = self.graph.get_name_scope() - if self._name_scope: - self._name_scope += "/" - if self.replica_id > 0: - if not self._name_scope: - self._name_scope = "" - self._name_scope += "replica_%d/" % self.replica_id - - def run(self): - # pylint: disable=protected-access - self.graph._variable_creator_stack = self._variable_creator_stack - self.should_run.wait() - self.should_run.clear() - try: - if self.coord.should_stop(): - return - with self.coord.stop_on_exception(), \ - context.context()._mode(self.context_mode), \ - context.context().device_policy(self.context_device_policy), \ - _enter_graph(self.graph), \ - MirroredReplicaContext(self.distribution, self.replica_id), \ - ops.device(self.device), \ - ops.name_scope(self._name_scope), \ - variable_scope.variable_scope( - self._captured_var_scope, reuse=self.replica_id > 0), \ - variable_scope.variable_creator_scope(self.variable_creator_fn): - self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) - self.done = True - finally: - self.has_paused.set() - - -class MirroredReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext used in MirroredStrategy.call_for_each_replica(). 
- - Opened in `_MirroredReplicaThread`, to allow the user to invoke - `MirroredStrategy`'s specific implementation of `merge_call()`, - which works by delegating the function and its arguments to - the main thread (the one that invoked - `MirroredStrategy.call_for_each_replica()`). - """ - - def _merge_call(self, fn, args, kwargs): - """Delegate to the main thread to actually perform merge_call().""" - t = threading.current_thread() # a _MirroredReplicaThread - t.merge_fn = fn - t.merge_args = args - t.merge_kwargs = kwargs - t.captured_name_scope = t.graph.get_name_scope() - # Adding a "/" at end lets us re-enter this scope later. - if t.captured_name_scope: - t.captured_name_scope += "/" - t.has_paused.set() - t.should_run.wait() - t.should_run.clear() - if t.coord.should_stop(): - raise _RequestedStop() - return t.merge_result - - @property - def device(self): - raise RuntimeError("Use .devices instead") + return values.MultiWorkerDataset( + functools.partial(self._call_dataset_fn, dataset_fn), + self._worker_devices, + auto_shard=self._auto_shard_dataset) + # TODO(priyag): Delete this once all strategies use global batch size. 
@property - def devices(self): - distribute_lib.require_replica_context(self) - return [self._distribution_strategy.worker_devices[self._replica_id]] + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 0dbf6ba0567a3637d3ebfca6df05804dd61e07c3..36be5c83f8bafb6c934d1d7682b5227b1f71c089 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,22 +20,27 @@ from __future__ import print_function import sys +from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import func_graph from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core 
@@ -46,8 +51,6 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib @@ -56,253 +59,229 @@ from tensorflow.python.training import server_lib GPU_TEST = "test_gpu" in sys.argv[0] -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=["graph", "eager"])) +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): - def _get_distribution_strategy(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - if context.num_gpus() > 1: - devices = ["/device:GPU:0", "/device:GPU:1"] - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - return mirrored_strategy.MirroredStrategy(devices) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - soft_placement = not GPU_TEST - print("testMinimizeLossGraph soft_placement:", soft_placement) - 
self._test_minimize_loss_graph( - self._get_distribution_strategy(), soft_placement=soft_placement) - - def testMapReduce(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_device_index(self._get_distribution_strategy()) - - def testReplicaId(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_replica_id(self._get_distribution_strategy()) - - def testNumReplicas(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy().num_replicas) - - def testNumReplicasInSync(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy(). - num_replicas_in_sync) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testRunRegroupError(self): - - def run_fn(device_id): + def testNumReplicasInSync(self, distribution): + self.assertEqual(2, distribution.num_replicas_in_sync) + + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) + + def testRunRegroupError(self, distribution): + def run_fn(): + replica_id = int(self.evaluate(_replica_id())) # Generates a list with different lengths on different devices. # Will fail in _regroup() (if more than one device). 
- return list(range(device_id)) - - dist = self._get_distribution_strategy() - with dist.scope(), self.assertRaises(AssertionError): - dist.call_for_each_replica(run_fn, args=(dist.worker_device_index,)) - - @test_util.run_in_graph_and_eager_modes - def testReduceToCpu(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - def run_fn(device_id): - return device_id - - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_replica( - run_fn, args=(dist.worker_device_index,)) - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - expected = sum(range(len(dist.worker_devices))) - self.assertEqual(expected, self.evaluate(unwrapped[0])) - - @test_util.run_in_graph_and_eager_modes - def testReduceOnlyFirstReplicaUpdates(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - def run_fn(device_id): - return constant_op.constant(3 + 5 * device_id) - - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_replica( - run_fn, args=(dist.worker_device_index,)) - reduced = dist.reduce( - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - self.assertEqual(3, self.evaluate(unwrapped[0])) - - @test_util.run_in_graph_and_eager_modes() - def testReduceToMultipleDestinations(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - devices = ["/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, - 1.0, - destinations=["/device:CPU:0", "/device:GPU:0"]) - unwrapped = dist.unwrap(reduced) - self.assertEqual(2, 
len(unwrapped)) - self.assertEqual(1.0, self.evaluate(unwrapped[0])) + return list(range(replica_id)) + + with distribution.scope(), self.assertRaises(AssertionError): + distribution.extended.call_for_each_replica(run_fn) + + def testReduceToCpu(self, distribution): + with distribution.scope(): + result = distribution.extended.call_for_each_replica(_replica_id) + reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result) + expected = sum(range(distribution.num_replicas_in_sync)) + self.assertEqual(expected, self.evaluate(reduced)) + + def testMakeInputFnIterator(self, distribution): + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values) + + def testGlobalStepUpdate(self, distribution): + self._test_global_step_update(distribution) + + +def one_device_combinations(): + return combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_cpu, + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_cpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph", "eager"]) + + +class MirroredOneDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(one_device_combinations()) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) + + @combinations.generate(one_device_combinations()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) + + @combinations.generate(one_device_combinations()) + def 
testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) +class MirroredStrategyVariableCreatorStackTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) + def testCreatorStacksAreThreadLocal(self, distribution): + def model_fn(): + replica_id_str = str(self.evaluate(_replica_id())) + + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + replica_id_str + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + distribution.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = distribution.extended.call_for_each_replica(model_fn) + result = distribution.unwrap(result) + expected = ("main_thread:thread_0", "main_thread:thread_1") + self.assertEqual(expected, result) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredStrategyVariableCreationTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True + # TODO(priyag): Modify more tests to use this helper and check more + # properties. 
+ def _test_mv_properties(self, var, name): + self.assertIsInstance(var, values.MirroredVariable) + self.assertEqual(name, var.name) + for d in var.devices: + self.assertEqual(d, var.get(d).device) + + def testVariableInFuncGraph(self, distribution): + def model_fn(): + v = variable_scope.variable(2.0, name="bar") + ds_context.get_replica_context().merge_call(lambda _: _) + return v - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + v1 = variable_scope.variable(1.0, name="foo") + v2 = distribution.extended.call_for_each_replica(model_fn) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSingleVariable(self): - self._skip_eager_if_gpus_less_than(1) + self._test_mv_properties(v1, "foo:0") + self._test_mv_properties(v2, "bar:0") + def testSingleVariable(self, distribution): def model_fn(): # This variable should be created only once across the threads because of - # special variable_creator functions used by `dist.call_for_each_replica`. + # special variable_creator functions used by + # `distribution.extended.call_for_each_replica`. 
v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnnamedVariable(self): - self._skip_eager_if_gpus_less_than(1) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "foo:0") + def testUnnamedVariable(self, distribution): def model_fn(): v = variable_scope.variable(1.0) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - # Default name of "Variable" will be used. 
- self.assertEquals("Variable:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariables(self): - self._skip_eager_if_gpus_less_than(1) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "Variable:0") + def testMultipleVariables(self, distribution): def model_fn(): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals("foo" + str(i) + ":0", v.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariablesWithSameCanonicalName(self): - self._skip_eager_if_gpus_less_than(1) + self._test_mv_properties(v, "foo" + str(i) + ":0") + def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): vs = [] vs.append(variable_scope.variable(1.0, name="foo/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for v in result: self.assertIsInstance(v, 
values.MirroredVariable) - self.assertEquals(4, len(result)) - self.assertEquals("foo/bar:0", result[0].name) - self.assertEquals("foo_1/bar:0", result[1].name) - self.assertEquals("foo_1/bar_1:0", result[2].name) - self.assertEquals("foo/bar_1:0", result[3].name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testVariableWithSameCanonicalNameAcrossThreads(self): - self._skip_eager_if_gpus_less_than(1) - - def model_fn(device_id): - v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v + self.assertEqual(4, len(result)) + self.assertEqual("foo/bar:0", result[0].name) + self.assertEqual("foo_1/bar:0", result[1].name) + self.assertEqual("foo_1/bar_1:0", result[2].name) + self.assertEqual("foo/bar_1:0", result[3].name) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) + def testVariableWithSameCanonicalNameAcrossThreads(self, distribution): + def model_fn(): + replica_id = self.evaluate(_replica_id()) + v = variable_scope.variable(1.0, name="foo_" + str(replica_id)) + ds_context.get_replica_context().merge_call(lambda _: _) + return v - with dist.scope(): - result = dist.call_for_each_replica( - model_fn, args=(dist.worker_device_index,)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # The resulting mirrored variable will use the name from the first device. 
- self.assertEquals("foo_0:0", result.name) + self.assertEqual("foo_0:0", result.name) - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithLayers(self): - self._skip_eager_if_gpus_less_than(1) + def testWithLayers(self, distribution): def model_fn(features): with variable_scope.variable_scope("common"): layer1 = core.Dense(1) @@ -310,17 +289,14 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - ds = dist.distribute_dataset( + ds = distribution.distribute_dataset( lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) if context.executing_eagerly(): iterator = ds.make_one_shot_iterator() @@ -330,26 +306,23 @@ class MirroredStrategyVariableCreationTest(test.TestCase): features = iterator.get_next() - with dist.scope(): - result = dist.call_for_each_replica(model_fn, args=(features,)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica( + model_fn, args=(features,)) suffixes = ["", "_1", "_2"] for (kernel, bias), suffix in zip(result, suffixes): self.assertIsInstance(kernel, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name) self.assertIsInstance(bias, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/dense" + 
suffix + "/bias:0", bias.name) + def testWithVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.variable(1.0, name="var0", aggregation=None) with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -363,37 +336,31 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): v = variable_scope.variable(1.0, name="var-main0") - self.assertEquals("var-main0:0", v.name) + self.assertEqual("var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("var0:0", v0.name) + self.assertEqual("var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("common/var1:0", v1.name) + self.assertEqual("common/var1:0", v1.name) self.assertIsInstance(v2, values.ReplicaLocalVariable) - self.assertEquals("common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertEqual("common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - self.assertEquals("common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithGetVariableAndVariableScope(self): - 
self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation) + def testWithGetVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -405,33 +372,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): with variable_scope.variable_scope("main"): v = variable_scope.get_variable("var-main0", [1]) - self.assertEquals("main/var-main0:0", v.name) + self.assertEqual("main/var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var0:0", v0.name) + self.assertEqual("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var1:0", v1.name) + self.assertEqual("main/common/var1:0", v1.name) self.assertIsInstance(v2, values.ReplicaLocalVariable) - self.assertEquals("main/common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, - v2.aggregation) + self.assertEqual("main/common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, + v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - 
self.assertEquals("main/common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testOnlyFirstReplicaUpdatesVariables(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("main/common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + def testOnlyFirstReplicaUpdatesVariables(self, distribution): def create_fn(): aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA v0 = variable_scope.variable( @@ -447,71 +409,73 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1 devices = ["/device:GPU:0", "/device:CPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - v0, v1 = dist.call_for_each_replica(create_fn) + with distribution.scope(): + v0, v1 = distribution.extended.call_for_each_replica(create_fn) self.evaluate(v0.initializer) self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0))) self.evaluate(v1.initializer) self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1))) + + def replica_id_plus_one(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) # Update using the assign_add member function. 
- def update_member_fn(device_id): - update0 = v0.assign_add(5.0 * (device_id + 1)) - update1 = v1.assign_add(7.0 * (device_id + 1)) + def update_member_fn(): + update0 = v0.assign_add(5.0 * replica_id_plus_one()) + update1 = v1.assign_add(7.0 * replica_id_plus_one()) return update0, update1 - update0a, update1a = dist.call_for_each_replica( - update_member_fn, args=(dist.worker_device_index,)) + update0a, update1a = distribution.extended.call_for_each_replica( + update_member_fn) # Update "sync on read" variable. - self.evaluate(dist.group(update0a)) + self.evaluate(distribution.group(update0a)) self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) # Writes are not synchronized for "sync on read" variables, # so device[1] can end up with a different value. self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) # Always reads from device 0. - self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1a)) + self.evaluate(distribution.group(update1a)) self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) # Writes are synchronized for v1, only the argument to assign_add on # device[0] is used. self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0 + 7.0, self.evaluate( + distribution.extended.read_var(v1))) # Update using state_ops.assign_add global function. 
- def update_state_ops_fn(device_id): - update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1)) - update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1)) + def update_state_ops_fn(): + update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one()) + update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one()) return update0, update1 - update0b, update1b = dist.call_for_each_replica( - update_state_ops_fn, args=(dist.worker_device_index,)) - self.evaluate(dist.group(update0b)) + update0b, update1b = distribution.extended.call_for_each_replica( + update_state_ops_fn) + self.evaluate(distribution.group(update0b)) # Update "sync on read" variable. self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1b)) + self.evaluate(distribution.group(update1b)) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1))) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate( + distribution.extended.read_var(v1))) + + def testNoneSynchronizationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. 
Please change " @@ -520,12 +484,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): "v", [1], synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testNoneSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. Please change " @@ -535,23 +495,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): name="v", synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable synchronization mode: Invalid for " "variable: v"): variable_scope.variable(1.0, name="v", synchronization="Invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -560,12 +512,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, 
aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -575,55 +523,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testThreeDevices(self): - self._skip_eager_if_gpus_less_than(2) - - def model_fn(): - v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNonMatchingVariableCreation(self): - self._skip_eager_if_gpus_less_than(1) - + def testNonMatchingVariableCreation(self, distribution): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): names = values.DistributedValues({ "/device:CPU:0": "foo", "/device:GPU:0": "bar" }) with self.assertRaises(RuntimeError): - _ = dist.call_for_each_replica(model_fn, args=(names,)) - - 
@test_util.run_in_graph_and_eager_modes(config=config) - def testReplicaLocalVariable(self): - self._skip_eager_if_gpus_less_than(1) + _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) + def testReplicaLocalVariable(self, distribution): all_v_sum = {} all_v_mean = {} components_sum = {} components_mean = {} - def model_fn(device_id): + def model_fn(): + replica_id = self.evaluate(_replica_id()) v_sum = variable_scope.variable( 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -634,26 +555,22 @@ class MirroredStrategyVariableCreationTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) self.assertTrue(isinstance(v_mean, values.ReplicaLocalVariable)) - updates = [v_sum.assign_add(2.0 + device_id), - v_mean.assign(6.0 * device_id)] - all_v_sum[device_id] = v_sum - all_v_mean[device_id] = v_mean + updates = [v_sum.assign_add(2.0 + replica_id), + v_mean.assign(6.0 * replica_id)] + all_v_sum[replica_id] = v_sum + all_v_mean[replica_id] = v_mean c_sum = v_sum.get() c_mean = v_mean.get() - components_sum[device_id] = c_sum - components_mean[device_id] = c_mean + components_sum[replica_id] = c_sum + components_mean[replica_id] = c_mean self.assertIsNot(v_sum, c_sum) self.assertIsNot(v_mean, c_mean) return updates, v_sum, v_mean, c_sum, c_mean - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): # Create "sum" and "mean" versions of ReplicaLocalVariables. ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( - dist.call_for_each_replica( - model_fn, args=(dist.worker_device_index,))) + distribution.extended.call_for_each_replica(model_fn)) # Should see the same wrapping instance in all replicas. 
self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) @@ -668,10 +585,10 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Apply updates self.evaluate(variables.global_variables_initializer()) - self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)]) expected_sum = 0.0 expected_mean = 0.0 - for i, d in enumerate(dist.worker_devices): + for i, d in enumerate(distribution.extended.worker_devices): # Should see different values on different devices. v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) @@ -681,69 +598,125 @@ class MirroredStrategyVariableCreationTest(test.TestCase): expected = i * 6.0 self.assertEqual(expected, v_mean_value) expected_mean += expected - expected_mean /= len(dist.worker_devices) + expected_mean /= len(distribution.extended.worker_devices) # Without get(device), should return the value you get by # applying the reduction across all replicas (whether you use # read_var(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate( + distribution.extended.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate( + distribution.extended.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + # TODO(priyag): Update this test to work in eager mode as well. 
+ def testDynamicRnnVariables(self, distribution): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + # Two variables are created by the RNN layer. + self.assertEqual(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = distribution.unwrap(v) + self.assertStartsWith(v1._op.name, "replica_1/") + + def testReplicaLocalVariableUpdate(self, distribution): + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + return v_sum + + def update(var, value): + return var.assign(value) + + with distribution.scope(): + ret_v_sum = distribution.extended.call_for_each_replica(model_fn) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values before running the update ops. + self.assertEqual(1.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + update_ops = distribution.extended.update( + ret_v_sum, update, args=(5.0,), group=False) + self.evaluate(update_ops) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values after running the update ops. 
+ self.assertEqual(5.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(10.0, self.evaluate(ret_v_sum)) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) +class MirroredStrategyNameScopeTest(test.TestCase): # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. - def testNameScope(self): + def testNameScope(self, distribution): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) b = constant_op.constant(1.0, name="b") return a, b - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): - result = dist.call_for_each_replica(model_fn) - self.assertEquals(2, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("main/foo/" + name + ":0", v0.name) - self.assertEquals("main/replica_1/foo/" + name + ":0", v1.name) + v0, v1 = distribution.unwrap(v) + self.assertEqual("main/foo/" + name + ":0", v0.name) + self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name) - def testWithDefaultName(self): + def testWithDefaultName(self, distribution): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) b = constant_op.constant(2.0, name="b") 
return a, b - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn) - self.assertEquals(2, len(result)) + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("foo/" + name + ":0", v0.name) - self.assertEquals("replica_1/foo/" + name + ":0", v1.name) + v0, v1 = distribution.unwrap(v) + self.assertEqual("foo/" + name + ":0", v0.name) + self.assertEqual("replica_1/foo/" + name + ":0", v1.name) # variable_scope.variable() respects name scopes when creating # variables. On the other hand variable_scope.get_variable() ignores name # scopes when creating variables. We test both methods of creating variables # to make sure that we have the same variable names in both cases. 
- def testNameScopeWithVariable(self): + def testNameScopeWithVariable(self, distribution): def in_cross_replica(_): c = variable_scope.variable(1.0, name="c") return c @@ -751,32 +724,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribution_strategy_context.get_replica_context().merge_call( - in_cross_replica) + c = ds_context.get_replica_context().merge_call(in_cross_replica) return b, c - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): a = variable_scope.variable(1.0, name="a") - result = dist.call_for_each_replica(model_fn) + result = distribution.extended.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("main/a:0", a0.name) - self.assertEquals("main/a/replica_1:0", a1.name) - self.assertEquals("main/b:0", b0.name) - self.assertEquals("main/b/replica_1:0", b1.name) - self.assertEquals("main/foo/c:0", c0.name) - self.assertEquals("main/foo/c/replica_1:0", c1.name) - - def testNameScopeWithGetVariable(self): + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("main/a:0", a0.name) + self.assertEqual("main/a/replica_1:0", a1.name) + self.assertEqual("main/b:0", b0.name) + self.assertEqual("main/b/replica_1:0", b1.name) + self.assertEqual("main/foo/c:0", c0.name) + self.assertEqual("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self, distribution): def in_cross_replica(_): c = variable_scope.get_variable("c", [1]) return c @@ -784,118 +753,80 @@ 
class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribution_strategy_context.get_replica_context().merge_call( - in_cross_replica) + c = ds_context.get_replica_context().merge_call(in_cross_replica) return b, c - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): a = variable_scope.get_variable("a", [1]) - result = dist.call_for_each_replica(model_fn) + result = distribution.extended.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("a:0", a0.name) - self.assertEquals("a/replica_1:0", a1.name) - self.assertEquals("b:0", b0.name) - self.assertEquals("b/replica_1:0", b1.name) - self.assertEquals("c:0", c0.name) - self.assertEquals("c/replica_1:0", c1.name) - - def testDynamicRnnVariables(self): + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("a:0", a0.name) + self.assertEqual("a/replica_1:0", a1.name) + self.assertEqual("b:0", b0.name) + self.assertEqual("b/replica_1:0", b1.name) + self.assertEqual("c:0", c0.name) + self.assertEqual("c/replica_1:0", c1.name) + + +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2), + combinations.NamedDistribution( + "CoreMirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + 
["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2) + ], + mode=["graph", "eager"])) +class MirroredThreeDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testThreeDevices(self, distribution): def model_fn(): - inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) - cell_fw = rnn_cell_impl.LSTMCell(300) - cell_bw = rnn_cell_impl.LSTMCell(300) - (outputs, _) = rnn.bidirectional_dynamic_rnn( - cell_fw, - cell_bw, - inputs, - dtype=dtypes.float32) - return outputs - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn) - # Two variables are created by the RNN layer. - self.assertEquals(2, len(result)) - for v in result: - self.assertIsInstance(v, values.DistributedValues) - _, v1 = dist.unwrap(v) - self.assertStartsWith(v1.name, "replica_1/") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testReplicaLocalVariableUpdate(self): - with context.graph_mode(): - - def model_fn(): - v_sum = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) - return v_sum - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]) - - def update(var, value): - return var.assign(value) - - with dist.scope(): - ret_v_sum = dist.call_for_each_replica(model_fn) - update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) - - # Initialize variables. - self.evaluate(variables.global_variables_initializer()) - # Assert that the aggregated value of the replica local vars is the sum - # of the individual values before running the update ops. 
- self.assertEquals(1.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(2.0, self.evaluate(ret_v_sum)) + v = variable_scope.variable(1.0, name="foo") + ds_context.get_replica_context().merge_call(lambda _: _) + return v - # Apply updates. - self.evaluate(update_ops) - # Assert that the aggregated value of the replica local vars is the sum - # of the individual values after running the update ops. - self.assertEquals(5.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(10.0, self.evaluate(ret_v_sum)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEqual("foo:0", result.name) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredVariableUpdateTest(test.TestCase): # The following tests check assign, assign_add and assign_sub on Mirrored # variables in replica and cross replica context. - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithoutAggregationType(self): + def testAssignMirroredVarReplicaContextWithoutAggregationType(self, + distribution): # Test that we always have an aggregation type set on the mirrored variable # if we assign to it in replica mode. 
- self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -905,23 +836,19 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "You must specify an aggregation method to update a " "MirroredVariable in Replica Context."): - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithSum(self): + def testAssignMirroredVarReplicaContextWithSum(self, distribution): # Test that we don't reduce a non-per-replica value with the "sum" # aggregation type. 
- self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -930,219 +857,184 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non-DistributedValues value 5.0 cannot be reduced " - "with the given aggregation VariableAggregation.SUM."): - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) + "with the given reduce op ReduceOp.SUM."): + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) - self.assertEquals(6.0, mirrored_var_result) + self.assertEqual(6.0, mirrored_var_result) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContext(self): - 
self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(0.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(0.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, 
self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(5.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) # read_value == True mirrored_var_result = self.evaluate( mirrored_var.assign_add(6.0, read_value=True)) - self.assertEquals(7.0, mirrored_var_result) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(7.0, mirrored_var_result) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) # read_value == False self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) - self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - 
@test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(1.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(1.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) 
self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(6.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(6.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(5.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) - self.assertEquals(3.0, mirrored_var_result) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(3.0, mirrored_var_result) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", 
aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_sub(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - self.assertEquals(4.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) - 
self.assertEquals(4.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.0, self.evaluate(mirrored_var)) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def testAssignMirroredVarInitializer(self): + def testAssignMirroredVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. with context.graph_mode(): @@ -1150,17 +1042,14 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.assertFalse(self.evaluate(mirrored_var.is_initialized())) self.evaluate(mirrored_var.initializer) self.assertTrue(self.evaluate(mirrored_var.is_initialized())) - def testAssignReplicaLocalVarInitializer(self): + def testAssignReplicaLocalVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. 
with context.graph_mode(): @@ -1172,11 +1061,9 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica( + model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.assertFalse(self.evaluate(replica_local_var.is_initialized())) @@ -1184,17 +1071,14 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): self.assertTrue(self.evaluate(replica_local_var.is_initialized())) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class ReplicaLocalVariableAssignTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignReplicaLocalVarSumAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarSumAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1202,18 +1086,16 @@ class ReplicaLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.SUM) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) 
self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. - self.assertEqual(2.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(2.0, self.evaluate( + distribution.read_var(replica_local_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. tlv_ops = replica_local_var.assign(6.0) @@ -1221,11 +1103,10 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # On reading the replica local var we should get the assigned value back. # The value on all the replicas are added before being returned by # `read_var`. - self.assertEqual(6.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignReplicaLocalVarMeanAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1233,23 +1114,22 @@ class ReplicaLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. 
# When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. - self.assertEqual(1.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(1.0, self.evaluate( + distribution.read_var(replica_local_var))) tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. - self.assertEqual(6.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) class MockModel(object): @@ -1283,24 +1163,25 @@ class MiniModel(keras_training.Model): return self.fc(inputs) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredStrategyDefunTest(test.TestCase): - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") - - def _call_and_check(self, model_fn, inputs, expected_result, defuns, - two_variables=False): + def _call_and_check(self, distribution, model_fn, inputs, expected_result, + defuns, two_variables=False): cpu_dev = device_util.canonicalize("CPU:0") gpu_dev = device_util.canonicalize("GPU:0") devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): mock_model = MockModel(two_variables) self.evaluate(variables.global_variables_initializer()) - result = dist.call_for_each_replica(model_fn, args=[mock_model] + inputs) + result = distribution.extended.call_for_each_replica( + model_fn, args=[mock_model] + inputs) for device in devices: device_result = values.select_device(device, result) device_expected_result = 
values.select_device(device, expected_result) @@ -1312,17 +1193,15 @@ class MirroredStrategyDefunTest(test.TestCase): # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. - per_replica_graph_functions = dist.call_for_each_replica( - defun.get_concrete_function, args=[mock_model] + inputs) + per_replica_graph_functions = ( + distribution.extended.call_for_each_replica( + defun.get_concrete_function, args=[mock_model] + inputs)) for device in devices: graph_function = per_replica_graph_functions.get(device=device) self.assertEqual(set(mock_model.variables), set(graph_function.graph.variables)) - @test_util.run_in_graph_and_eager_modes() - def testVariableInDefun(self): - self._skip_eager_if_gpus_less_than(1) - + def testVariableInDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1330,12 +1209,9 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return times_two(mock_model) - self._call_and_check(model_fn, [], 2.5, [times_two]) - - @test_util.run_in_graph_and_eager_modes() - def testVariableInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 2.5, [times_two]) + def testVariableInNestedDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1347,12 +1223,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return two_x_plus_one(mock_model) - self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) - - @test_util.run_in_graph_and_eager_modes() - def testTwoVariablesInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 3.5, + [times_two, two_x_plus_one]) + def testTwoVariablesInNestedDefun(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1364,12 +1238,10 @@ 
class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return fn2(mock_model) - self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) - - @test_util.run_in_graph_and_eager_modes() - def testGradientTapeOverNestedDefuns(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2], + two_variables=True) + def testGradientTapeOverNestedDefuns(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1385,13 +1257,10 @@ class MirroredStrategyDefunTest(test.TestCase): [v.get() for v in mock_model.variables]) return grads - self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2], two_variables=True) - @test_util.run_in_graph_and_eager_modes() - def testPassPerReplica(self): - self._skip_eager_if_gpus_less_than(1) - + def testPassPerReplica(self, distribution): @function.defun def fn1(mock_model, factor): return mock_model(factor) @@ -1399,18 +1268,10 @@ class MirroredStrategyDefunTest(test.TestCase): factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, "GPU:0": 3.0 * 1.25}) - self._call_and_check(fn1, [factors], expected_result, [fn1]) - - @test_util.run_in_graph_and_eager_modes() - def testTrain(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) - cpu_dev = device_util.canonicalize("CPU:0") - gpu_dev = device_util.canonicalize("GPU:0") - devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) - - with dist.scope(): + def testTrain(self, distribution): + with distribution.scope(): mock_model = MiniModel() mock_model.call = function.defun(mock_model.call) @@ -1420,10 +1281,11 @@ class MirroredStrategyDefunTest(test.TestCase): gradients_fn = backprop.implicit_grad(loss_fn) gradients_fn = 
optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = dist.call_for_each_replica(gradients_fn, args=(None,)) + grads_and_vars = distribution.extended.call_for_each_replica( + gradients_fn, args=(None,)) optimizer = gradient_descent.GradientDescentOptimizer(0.25) - update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access + update_ops = optimizer._distributed_apply(distribution, grads_and_vars) # pylint: disable=protected-access if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) @@ -1435,30 +1297,82 @@ class MirroredStrategyDefunTest(test.TestCase): self.assertAllEqual([0.5], updated_var_values[1]) +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker= + context.num_gpus()), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()), + required_gpus=1) + ], + mode=["graph"])) class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, strategy_test_lib.DistributionTestBase): - def _get_distribution_strategy(self): + def _configure_distribution_strategy(self, distribution): cluster_spec = server_lib.ClusterSpec({ "worker": ["/job:worker/task:0", "/job:worker/task:1"] }) - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure(cluster_spec=cluster_spec) - return strategy - - def test_num_replicas_in_sync(self): - if not GPU_TEST: - self.skipTest("Not GPU test") + distribution.configure(cluster_spec=cluster_spec) - strategy = self._get_distribution_strategy() + def test_num_replicas_in_sync(self, distribution): + self._configure_distribution_strategy(distribution) # We calculate the total number of gpus across the 
workers(2) specified in # the cluster spec. - self.assertEqual(context.num_gpus() * 2, strategy.num_replicas_in_sync) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy(), - learning_rate=0.05) + self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync) + + def testMinimizeLossGraph(self, distribution): + self._configure_distribution_strategy(distribution) + self._test_minimize_loss_graph(distribution, learning_rate=0.05) + + def testDeviceScope(self, distribution): + """Test the device scope of multi-worker MirroredStrategy.""" + self._configure_distribution_strategy(distribution) + with distribution.scope(): + a = constant_op.constant(1.) + with ops.device("/cpu:0"): + b = constant_op.constant(1.) + self.assertEqual(a.device, "/job:worker/task:0") + self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") + + def testMakeInputFnIterator(self, distribution): + self._configure_distribution_strategy(distribution) + dataset_fn = lambda: dataset_ops.Dataset.range(100) + num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [[i+j for j in range(num_gpus)] * num_workers + for i in range(0, 100, num_gpus)] + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. 
+ input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess) + + def testUpdateConfigProto(self, distribution): + distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) + + config_proto = config_pb2.ConfigProto() + new_config = distribution.update_config_proto(config_proto) + + # Verify isolate_session_state + self.assertTrue(new_config.isolate_session_state) class MultiWorkerMirroredStrategyTestWithChief( @@ -1478,6 +1392,19 @@ class MultiWorkerMirroredStrategyTestWithChief( strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) + def testMinimizeLossGraphCoreMirroredStrategy(self): + strategy = mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_minimize_loss_graph(strategy, learning_rate=0.05) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py deleted file mode 100644 index b5d393fd0dc8d3524bf356b7e60480d6056fd550..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for class MirroredStrategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context - - -class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) - - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) - - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - - def testReplicaId(self): - self._test_replica_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - -class 
VariableCreatorStackTest(test.TestCase): - - def testCreatorStacksAreThreadLocal(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - - def model_fn(device_id): - assert isinstance(device_id, int) - - def thread_creator_fn(next_creator, *args, **kwargs): - return next_creator(*args, **kwargs) + ":thread_" + str(device_id) - - with variable_scope.variable_creator_scope(thread_creator_fn): - # Create a variable in this scope. - v = variable_scope.variable(1.0) - - # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v - - def main_thread_creator(next_creator, *args, **kwargs): - # We are not using the underlying next_creator for test purposes. - del next_creator, args, kwargs - return "main_thread" - - with context.graph_mode(), \ - dist.scope(), \ - variable_scope.variable_creator_scope(main_thread_creator): - result = dist.call_for_each_replica( - model_fn, args=(dist.worker_device_index,)) - result = dist.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] - self.assertEquals(expected, result) - - -class MultiWorkerMirroredStrategyTest(test.TestCase): - - def testDeviceScope(self): - """Test the device scope of multi-worker MirroredStrategy.""" - with context.graph_mode(): - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure( - cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]}) - with strategy.scope(): - a = constant_op.constant(1.) - with ops.device("/cpu:0"): - b = constant_op.constant(1.) 
- self.assertEqual(a.device, "/job:worker/task:0") - self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 7ecc852d20508cc7063f3598c9fef03d6ce536a5..8f13e9153ea7a951dd722c4549882c97e79b57fe 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -32,7 +32,8 @@ from tensorflow.python.training import moving_averages all_combinations = combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=["graph"]) @@ -138,6 +139,27 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], var.eval()) + @combinations.generate(all_combinations) + def testAssignVariable(self, distribution): + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + # Here we expect to check the case when input value are variable. + val = variables.Variable([1., 2.]) + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + self.assertAllClose( + [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. 
* (1 - 0.25)], + var.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 8eec3dc0f6ec0676353c7434d203e017b9aab80d..147c9b83f866fd364ea23cf7988692a7b5f61b9c 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import contextlib import copy +import json +import os import threading import numpy as np @@ -271,7 +274,6 @@ class MultiWorkerTestBase(test.TestCase): return config - def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, **kwargs): result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) @@ -303,3 +305,101 @@ class MultiWorkerTestBase(test.TestCase): for t in threads: t.join() self.assertEqual(self._result, len(threads)) + + +class MockOsEnv(collections.Mapping): + """A class that allows per-thread TF_CONFIG.""" + + def __init__(self, *args): + self._dict = dict() + self._thread_local = threading.local() + super(MockOsEnv, self).__init__(*args) + + def get(self, key, default=None): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.get(self._thread_local.dict, key, default) + else: + return dict.get(self._dict, key, default) + + def __getitem__(self, key): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__getitem__(self._thread_local.dict, key) + else: + return dict.__getitem__(self._dict, key) + + def __setitem__(self, key, val): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__setitem__(self._thread_local.dict, key, val) + else: + return 
dict.__setitem__(self._dict, key, val) + + def __iter__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + for x in self._thread_local.dict.items(): + yield x + for x in self._dict.items(): + yield x + + def __len__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + return self._thread_local.dict.__len__() + self._dict.__len__() + + +class IndependentWorkerTestBase(test.TestCase): + """Testing infra for independent workers.""" + + def setUp(self): + self._mock_os_env = MockOsEnv() + self._mock_context = test.mock.patch.object(os, 'environ', + self._mock_os_env) + super(IndependentWorkerTestBase, self).setUp() + self._mock_context.__enter__() + + def tearDown(self): + self._mock_context.__exit__(None, None, None) + super(IndependentWorkerTestBase, self).tearDown() + + def _task_thread(self, task_fn, tf_config, *args, **kwargs): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) + + def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, + *args, **kwargs): + if task_type: + tf_config = { + 'cluster': cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + else: + tf_config = { + 'cluster': cluster_spec, + } + t = threading.Thread( + target=self._task_thread, + args=(task_fn, tf_config) + args, + kwargs=kwargs) + t.start() + return t + + def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args, + **kwargs): + # The task_fn should create std_server by itself. 
+ threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id, + *args, **kwargs) + threads[task_type].append(t) + return threads diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 1b4251b761b2b95b4e41fbd8c8d5e31e5e1b2d25..fdbfba4e04358451a46b23ef250dc7c534c855a0 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,14 +20,14 @@ from __future__ import print_function import six -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -41,7 +41,14 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # implementations? 
def __init__(self, device): - super(OneDeviceStrategy, self).__init__() + super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) + + +class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of OneDeviceStrategy.""" + + def __init__(self, container_strategy, device): + super(OneDeviceExtended, self).__init__(container_strategy) self._device = device self._default_device = device @@ -53,24 +60,40 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): if isinstance(colocate_with, six.string_types): with ops.device(colocate_with): return next_creator(*args, **kwargs) - if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and isinstance(colocate_with[0], six.string_types)): with ops.device(colocate_with[0]): return next_creator(*args, **kwargs) with ops.colocate_with(colocate_with): return next_creator(*args, **kwargs) - def distribute_dataset(self, dataset_fn): + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch.""" + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + return values.DatasetIterator(dataset, worker_device_pairs) + + def _distribute_dataset(self, dataset_fn): return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), [self._device]) - def _broadcast(self, tensor, destinations): + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + return values.InputFunctionIterator( + input_fn, worker_device_pairs, + [distribute_lib.InputContext()]) + + def _broadcast_to(self, tensor, destinations): del destinations return tensor # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
- def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): + def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, + initial_loop_values=None): if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) @@ -82,7 +105,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): fn_inputs = iterator.get_next() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs @@ -116,39 +139,24 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return ctx def _call_for_each_replica(self, fn, args, kwargs): - with ops.device(self._device), _OneDeviceReplicaContext(self): + strategy = self._container_strategy() + with ops.device(self._device), _OneDeviceReplicaContext(strategy): return fn(*args, **kwargs) - def map(self, map_over, fn, *args, **kwargs): - with ops.device(self._device): - return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - - def _reduce(self, aggregation, value, destinations): - del destinations - if not isinstance(value, values.MapOutput): - return value - l = value.get() - assert l - with ops.device(self._device): - if aggregation == vs.VariableAggregation.SUM: - return math_ops.add_n(l) - elif aggregation == vs.VariableAggregation.MEAN: - return math_ops.add_n(l) / len(l) - else: - assert False + def _reduce_to(self, reduce_op, value, destinations): + del reduce_op, destinations + return value - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. 
- return self._update_non_slot(var, options, fn, var, *args, **kwargs) + return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. with ops.device(self._device), distribute_lib.UpdateContext(self._device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -158,33 +166,43 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return array_ops.identity(replica_local_var) def _unwrap(self, value): - return [value] + return (value,) def value_container(self, value): return value @property - def num_replicas(self): - return 1 - - @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return 1 @property def worker_devices(self): - return [self._device] + return (self._device,) @property def parameter_devices(self): - return [self._device] + return (self._device,) def non_slot_devices(self, var_list): del var_list - return [self._device] + return (self._device,) + + @property + def experimental_should_init(self): + return True + + @property + def should_checkpoint(self): + return True - def _worker_device_index(self): - return 0 + @property + def should_save_summary(self): + return True + + # TODO(priyag): Delete this once all strategies use global batch size. 
+ @property + def _global_batch_size(self): + return True class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): @@ -192,12 +210,10 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( - self, distribution_strategy, replica_id=0) - - @property - def device(self): - raise RuntimeError("Use .devices instead") + self, + distribution_strategy, + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property def devices(self): - return [self._distribution_strategy.worker_devices[0]] + return self._distribution_strategy.extended.worker_devices diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 3fb92273924a665bf2a1ee5fc94b75273b8c5f78..d46cd6f529e363f76bfa2b22339add63530cfde8 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -35,12 +36,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - def testReplicaId(self): self._test_replica_id(self._get_distribution_strategy()) @@ -48,6 +43,20 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testCallAndMergeExceptions(self): 
self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + @test_util.run_in_graph_and_eager_modes + def testMakeInputFnIterator(self): + d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = d.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, d.extended.worker_devices, expected_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 790b37f86010eba6bdc87e6424e55a97629c5d1a..2c7766f95fbcb7b68a53ad0052f21485c763a1db 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +import copy + from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -30,8 +34,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from 
tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_setter -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest _LOCAL_CPU = "/device:CPU:0" @@ -94,13 +96,21 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): ValueError: if `cluster_spec` is given but `task_type` or `task_id` is not. """ - super(ParameterServerStrategy, self).__init__() + super(ParameterServerStrategy, self).__init__( + ParameterServerExtended(self, num_gpus_per_worker)) + + +class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of ParameterServerStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + super(ParameterServerExtended, self).__init__(container_strategy) self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local(num_gpus_per_worker) # We typically don't need to do all-reduce in this strategy. - self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + self._cross_device_ops = ( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_LOCAL_CPU)) def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, @@ -135,14 +145,14 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. 
if num_gpus_per_worker > 0: - self._compute_devices = [ + self._compute_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - self._compute_devices = [self._worker_device] + self._compute_devices = (self._worker_device,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) @@ -166,8 +176,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # The `_parameter_devices` is needed for the `parameter_devices` property # and is a list of all variable devices. Here parameter devices are all # tasks of the "ps" job. - self._parameter_devices = map("/job:ps/task:{}".format, - range(num_ps_replicas)) + self._parameter_devices = tuple(map("/job:ps/task:{}".format, + range(num_ps_replicas))) # Add a default device so that ops without specified devices will not end up # on other workers. @@ -189,28 +199,29 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def _initialize_local(self, num_gpus_per_worker): """Initialize internal devices for local training.""" + self._worker_device = device_util.canonicalize("/device:CPU:0") # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. if num_gpus_per_worker > 0: - self._compute_devices = list( + self._compute_devices = tuple( map("/device:GPU:{}".format, range(num_gpus_per_worker))) else: - self._compute_devices = [_LOCAL_CPU] + self._compute_devices = (_LOCAL_CPU,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) # If there is only one GPU, put everything on that GPU. Otherwise, place # variables on CPU. 
if num_gpus_per_worker == 1: - assert len(list(self._compute_devices)) == 1 + assert len(self._compute_devices) == 1 self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = [_LOCAL_GPU_0] + self._parameter_devices = (_LOCAL_GPU_0,) else: self._variable_device = _LOCAL_CPU - self._parameter_devices = [_LOCAL_CPU] + self._parameter_devices = (_LOCAL_CPU,) self._is_chief = True self._cluster_spec = None @@ -221,15 +232,48 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "ParameterServerStrategy with compute_devices = %r, " "variable_device = %r", self._compute_devices, self._variable_device) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._compute_devices, True) - def _broadcast(self, tensor, destinations): - if not cross_tower_ops_lib.check_destinations(destinations): + def _make_dataset_iterator(self, dataset): + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return values.DatasetIterator(dataset, worker_device_pairs, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + num_input_pipelines = multi_worker_util.worker_count( + self._cluster_spec, self._task_type) + else: + input_pipeline_id = 0 + num_input_pipelines = 1 + input_context = distribute_lib.InputContext( + num_input_pipelines=num_input_pipelines, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return values.InputFunctionIterator( + input_fn, worker_device_pairs, [input_context]) + + def 
_broadcast_to(self, tensor, destinations): + # This is both a fast path for Python constants, and a way to delay + # converting Python values to a tensor until we know what type it + # should be converted to. Otherwise we have trouble with: + # global_step.assign_add(1) + # since the `1` gets broadcast as an int32 but global_step is int64. + if isinstance(tensor, (float, int)): + return tensor + if not cross_device_ops_lib.check_destinations(destinations): destinations = self._compute_devices - return self._cross_tower_ops.broadcast(tensor, destinations) + return self._cross_device_ops.broadcast(tensor, destinations) def _allow_variable_partition(self): return not context.executing_eagerly() @@ -237,7 +281,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through # this creator, such as "MutableHashTable". def _create_variable(self, next_creator, *args, **kwargs): - if self.num_replicas_in_sync > 1: + if self._num_replicas_in_sync > 1: aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in ( vs.VariableAggregation.NONE, @@ -293,39 +337,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica(self, fn, args, kwargs) + return mirrored_strategy._call_for_each_replica( + self._container_strategy(), fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: return if destinations is None: return - for d in cross_tower_ops_lib.get_devices_from(destinations): + for d in cross_device_ops_lib.get_devices_from(destinations): d_spec = tf_device.DeviceSpec.from_string(d) if d_spec.job == self._task_type and d_spec.task != self._task_id: raise ValueError( "Cannot reduce to another worker: %r, current worker is %r" % (d, self._worker_device)) - def 
_reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return mirrored_strategy._reduce_non_distributed_value( - self, aggregation, value, destinations) - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return self.broadcast(value.get(self._compute_devices[0]), destinations) - return self._cross_tower_ops.reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return [self.broadcast(v.get(self._compute_devices[0]), d) - for v, d in value_destination_pairs] + self, reduce_op, value, destinations) + return self._cross_device_ops.reduce( + reduce_op, value, destinations=destinations) + + def _batch_reduce_to(self, reduce_op, value_destination_pairs): for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) - return self._cross_tower_ops.batch_reduce(aggregation, - value_destination_pairs) + return self._cross_device_ops.batch_reduce(reduce_op, + value_destination_pairs) def _select_single_value(self, structured): """Select any single values in `structured`.""" @@ -349,30 +389,26 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): result = fn(var, *self._select_single_value(args), **self._select_single_value(kwargs)) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -381,9 +417,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if isinstance(val, values.DistributedValues): # Return in a deterministic order. if set(val.devices) == self._canonical_compute_device_set: - return [val.get(device=d) for d in self._compute_devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] + return tuple(val.get(device=d) for d in self._compute_devices) + return tuple(val.get(device=d) for d in sorted(val.devices)) + return (val,) def value_container(self, val): if (hasattr(val, "_aggregating_container") and @@ -398,11 +434,11 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # variables. return array_ops.identity(var) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the strategy class. 
The strategy object will be re-initialized if `cluster_spec` is given but @@ -433,48 +469,50 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): self._initialize_multi_worker(self._num_gpus_per_worker, self._cluster_spec, task_type, task_id) - if not session_config or not self._cluster_spec: - return + if session_config: + session_config.CopyFrom(self._update_config_proto(session_config)) - session_config.isolate_session_state = False + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + if not self._cluster_spec: + updated_config.isolate_session_state = True + return updated_config + + updated_config.isolate_session_state = False - assert self._cluster_spec assert self._task_type assert self._task_id is not None # The device filters prevent communication between workers. if self._task_type not in ["chief", "worker"]: - return - del session_config.device_filters[:] - session_config.device_filters.extend( + return updated_config + del updated_config.device_filters[:] + updated_config.device_filters.extend( ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) + return updated_config @property - def num_replicas(self): - return len(self._compute_devices) - - @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return len(self._compute_devices) @property def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._compute_devices) + return self._compute_devices @property def parameter_devices(self): - return list(self._parameter_devices) + return self._parameter_devices def non_slot_devices(self, var_list): return min(var_list, key=lambda x: x.name) @property - def between_graph(self): + def experimental_between_graph(self): + # TODO(yuefengz): Should this return False in the local case? 
return True @property - def should_init(self): + def experimental_should_init(self): return self._is_chief @property @@ -484,3 +522,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): @property def should_save_summary(self): return self._is_chief + + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 81a23c89030221a8a15bdedc796c50d9c518138c..83d7473666a65e438a1c0119d2a12bf54e53c8fc 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -25,14 +25,21 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.layers import core from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -41,8 +48,6 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import training_util CHIEF = run_config.TaskType.CHIEF @@ -50,6 +55,13 @@ WORKER = run_config.TaskType.WORKER PS = run_config.TaskType.PS +def _get_replica_id_integer(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if isinstance(replica_id, ops.Tensor): + replica_id = tensor_util.constant_value(replica_id) + return replica_id + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -94,9 +106,8 @@ class ParameterServerStrategyTestBase( if num_gpus == 0: last_part_device = 'device:CPU:0' else: - last_part_device = ( - 'device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + last_part_device = ('device:GPU:%d' % replica_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -261,18 +272,16 @@ class ParameterServerStrategyTestBase( if 'CPU' in compute_device: replica_compute_device = '/device:CPU:0' else: - replica_compute_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + replica_compute_device = ('/device:GPU:%d' % replica_id) replica_compute_device = device_util.canonicalize( replica_compute_device) if 'CPU' in variable_device: replica_variable_device = '/device:CPU:0' else: - replica_variable_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + replica_variable_device = ('/device:GPU:%d' % 
replica_id) replica_variable_device = device_util.canonicalize( replica_variable_device) @@ -354,9 +363,9 @@ class ParameterServerStrategyTestBase( def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - if hasattr(d, '_cluster_spec') and d._cluster_spec: - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if 'chief' in d._cluster_spec.as_dict(): + if d.extended._cluster_spec: + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if 'chief' in d.extended._cluster_spec.as_dict(): num_workers += 1 else: num_workers = 1 @@ -389,7 +398,7 @@ class ParameterServerStrategyTestBase( x, y, z, train_op = d.call_for_each_replica(model_fn) train_op = d.group(train_op) - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True if task_id == 0: @@ -426,9 +435,9 @@ class ParameterServerStrategyTestBase( task_type, task_id, num_gpus) if task_type: # Multi-worker - assert hasattr(d, '_cluster_spec') and d._cluster_spec - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if CHIEF in d._cluster_spec.as_dict(): + assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d.extended._cluster_spec.as_dict(): num_workers += 1 else: # local @@ -472,8 +481,8 @@ class ParameterServerStrategyTestBase( before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
- g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -481,11 +490,12 @@ class ParameterServerStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True if (not task_type or - multi_worker_util.is_chief(d._cluster_spec, task_type, task_id)): + multi_worker_util.is_chief( + d.extended._cluster_spec, task_type, task_id)): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. @@ -508,8 +518,40 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_objects( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class ParameterServerStrategyTest(ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): @classmethod @@ -574,6 +616,73 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, def testMinimizeLossGraphLocal(self, num_gpus): self._test_minimize_loss_graph(None, None, num_gpus) + # TODO(priyag): Refactor this and other multi worker tests. + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorLocal(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) # only one worker and pipeline for local. 
+ self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + + def testGlobalStepUpdate(self): + strategy = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=context.num_gpus()) + self._test_global_step_update(strategy) + + def testUpdateConfigProtoMultiWorker(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + distribution.configure( + cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + + config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) + + new_config = distribution.update_config_proto(config_proto) + + # Verify device filters. + self.assertEqual(['/job:worker/task:1', '/job:ps'], + new_config.device_filters) + + # Verify isolate_session_state + self.assertFalse(new_config.isolate_session_state) + + def testUpdateConfigProtoLocal(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + + config_proto = config_pb2.ConfigProto() + new_config = distribution.update_config_proto(config_proto) + + # Verify isolate_session_state + self.assertTrue(new_config.isolate_session_state) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -616,9 +725,9 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.value_container(v) + w = distribution.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.call_for_each_replica(f) + distribution.extended.call_for_each_replica(f) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 3dc815f0371002bd3a8657f18ccc09a27bb14961..c928b6d9f1f21508edd753f94c38ab2723cc0a9f 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ 
b/tensorflow/contrib/distribute/python/step_fn.py @@ -94,7 +94,7 @@ class StandardSingleLossStep(StandardInputStep): def __call__(self): with self._distribution.scope(): - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): """Function to run one iteration with one input.""" gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 98cdb17b8ca2624ed8bbc55fc8a7fb7e76aa507e..d441b5af5f6aa41efde2c75d09d9589516c54992 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -19,16 +19,21 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -45,8 +50,7 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_replica() call, calls a # get_replica_context().merge_call() that raises an exception. 
def _merge_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _raise_exception_fn) + ds_context.get_replica_context().merge_call(_raise_exception_fn) # Must be the argument to a get_replica_context().merge_call() call, calls @@ -59,8 +63,7 @@ def _call_raises_fn(dist): # calls a get_replica_context().merge_call() that calls a # call_for_each_replica() that raises an exception. def _merge_call_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _call_raises_fn) + ds_context.get_replica_context().merge_call(_call_raises_fn) # Must be the argument to a get_replica_context().merge_call() call, calls @@ -74,8 +77,7 @@ def _call_merge_raises_fn(dist): # get_replica_context().merge_call() that calls a call_for_each_replica() that # calls a get_replica_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _call_merge_raises_fn) + ds_context.get_replica_context().merge_call(_call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -114,8 +116,8 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -169,8 +171,8 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -189,40 +191,20 @@ class 
DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_map_reduce(self, d, in_graph=None): - with d.scope(): - map_in = [constant_op.constant(i) for i in range(10)] - map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out, - "/device:CPU:0") - expected = 90 # 2 * (0 + 1 + ... + 9) - self.assertEqual(expected, observed.numpy()) - - def _test_device_index(self, d): - with d.scope(): - expected_devices = [False] * len(d.worker_devices) - - def mark_devices_fn(device_id): - self.assertLess(device_id, len(d.worker_devices)) - self.assertFalse(expected_devices[device_id]) - expected_devices[device_id] = True - - d.call_for_each_replica(mark_devices_fn, args=(d.worker_device_index,)) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) - def _test_replica_id(self, d): with d.scope(): - expected_devices = [False] * len(d.worker_devices) + expected_devices = [False] * len(d.extended.worker_devices) def mark_devices_fn(): - replica_id = ( - distribution_strategy_context.get_replica_context().replica_id) - self.assertLess(replica_id, len(d.worker_devices)) + replica_id = self.evaluate( + ds_context.get_replica_context().replica_id_in_sync_group) + self.assertLess(replica_id, len(d.extended.worker_devices)) self.assertFalse(expected_devices[replica_id]) expected_devices[replica_id] = True d.call_for_each_replica(mark_devices_fn) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + self.assertAllEqual(expected_devices, + [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): @@ -234,3 +216,78 @@ class DistributionTestBase(test.TestCase): dist.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): dist.call_for_each_replica(_merge_call_merge_raises_fn) + + def _input_fn_to_test_input_context(self, + dataset_fn, + 
expected_num_replicas_in_sync, + expected_num_input_pipelines, + expected_input_pipeline_id): + # Use a list of one element as counter so that it can be captured by the + # `_input_fn`. This counter is incremented by 1 each time an input_fn is + # called. We use this counter to check whether the `input_pipeline_id` + # matches the counter in the in-graph replication. + worker_id_counter = [0] + + def _input_fn(input_context): + """Input fn for testing.""" + self.assertIsNotNone(input_context) + self.assertEqual(expected_num_replicas_in_sync, + input_context.num_replicas_in_sync) + self.assertEqual(expected_num_input_pipelines, + input_context.num_input_pipelines) + if expected_input_pipeline_id is not None: + self.assertEqual(expected_input_pipeline_id, + input_context.input_pipeline_id) + else: + self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) + worker_id_counter[0] += 1 + + return dataset_fn() + + return _input_fn + + def _test_input_fn_iterator(self, iterator, devices, expected_values, + sess=None): + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + def _test_global_step_update(self, strategy): + with strategy.scope(): + global_step = variable_scope.get_variable( + "global_step", + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + train_op = global_step.assign_add(1) + value = global_step.read_value() + return train_op, value + + train_ops, value = strategy.call_for_each_replica(model_fn) + self.evaluate(strategy.group(train_ops)) + global_step_tensors = strategy.unwrap(value) + global_step_values = self.evaluate(global_step_tensors) + self.assertEqual((1,) * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index f5b4531ba8c483e69f2a2b5539b27205efb9fc21..b6f5b492017fc7dfd329e69ad9ca418ae682bc4b 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,31 +21,34 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import 
session as session_lib +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE" - - def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -130,8 +133,24 @@ class TPUStrategy(distribute_lib.DistributionStrategy): num_cores: Number of cores to use on the TPU. If None specified, then auto-detect the cores and topology of the TPU system. """ - super(TPUStrategy, self).__init__() + super(TPUStrategy, self).__init__(TPUExtended( + self, tpu_cluster_resolver, steps_per_run, num_cores)) + + @property + def steps_per_run(self): + """DEPRECATED: use .extended.steps_per_run instead.""" + return self._extended.steps_per_run + + +class TPUExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of TPUStrategy.""" + # Track what TPU devices have been initialized. 
+ _initialized_devices = [] + + def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, + num_cores=None): + super(TPUExtended, self).__init__(container_strategy) self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) # TODO(sourabhbajaj): Change this from num_cores to metadata_override @@ -143,16 +162,41 @@ class TPUStrategy(distribute_lib.DistributionStrategy): if "device:TPU:" in d.name} self._device_index = values.PerReplica(device_map) self._host_device = self.get_host_cpu_device(0) - self._tpu_devices = sorted(device_map.keys()) + self._tpu_devices = tuple(sorted(device_map.keys())) # Only create variables for the number of replicas we're running. - self._tpu_devices = self._tpu_devices[:self.num_replicas] + self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run - self._require_static_shapes = True + # Initialize the TPU devices. + self._initialize_tpu() + + def _initialize_tpu(self): + """Initialize the TPU devices in a separate session and graph. + + We keep track of all the TPU devices that we're initialized as we should + only be running TPU initialize once for the entire process. + """ + master = self._tpu_cluster_resolver.master() + # Verify TPU has not already been initialized in this process. + if master in TPUExtended._initialized_devices: + logging.info("TPU master %s has already been initialized." 
% master) + return + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + self._configure(session_config) + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + sess.run([tpu.initialize_system()]) + logging.info("Finized initializing TPU system.") + + # Update Strategy state to make sure we can track device initialization. + TPUExtended._initialized_devices.append(master) + def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. @@ -214,7 +258,17 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return enqueue_op_per_host - def distribute_dataset(self, dataset_fn): + def _make_dataset_iterator(self, dataset): + """Make iterators for each of the TPU hosts.""" + + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + return values.DatasetIterator(dataset, worker_devices, + self._num_replicas_in_sync) + + def _distribute_dataset(self, dataset_fn): worker_devices = [ (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) @@ -225,12 +279,11 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. 
- def _run_steps_on_dataset(self, fn, multi_worker_iterator, iterations, - initial_loop_values=None): - + def _experimental_run_steps_on_iterator( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): output_shapes = multi_worker_iterator.output_shapes shapes = nest.flatten(output_shapes) - if any([not s.is_fully_defined() for s in shapes]): + if any(not s.is_fully_defined() for s in shapes): raise ValueError( "TPU currently requires fully defined shapes. Either use " "set_shape() on the input tensors or use " @@ -251,13 +304,13 @@ class TPUStrategy(distribute_lib.DistributionStrategy): initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) ctx = values.MultiStepContext() - def run_fn(*args, **kwargs): + + def run_fn(): """Single step on the TPU device.""" - del args, kwargs fn_inputs = dequeue_fn() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -265,11 +318,6 @@ class TPUStrategy(distribute_lib.DistributionStrategy): else: return fn_result - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn - def iterate_on_tpu(): - return training_loop.repeat(iterations, run_fn, initial_loop_values) - # We capture the control_flow_context at this point, before we run `fn` # inside a while_loop and TPU replicate context. 
This is useful in cases # where we might need to exit these contexts and get back to the outer @@ -279,38 +327,70 @@ class TPUStrategy(distribute_lib.DistributionStrategy): self._outer_control_flow_context = ( ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - replicate_inputs = [[]] * self.num_replicas - replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + def rewrite_fn(*args): + """The rewritten step fn running on TPU.""" + del args + replicate_inputs = [[]] * self._num_replicas_in_sync + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) + + # If run_fn has tensor outputs, tpu.replicate returns a list of list. We + # will flatten it in this case. If run_fn has no tensor outputs, + # tpu.replicate returns a list of no_ops, we will keep the output as it + # is. + if isinstance(replicate_outputs[0], list): + replicate_outputs = nest.flatten(replicate_outputs) + + return replicate_outputs + + # TODO(sourabhbajaj): The input to while loop should be based on the output + # type of the step_fn + assert isinstance(initial_loop_values, list) + initial_loop_values = initial_loop_values * self._num_replicas_in_sync + + # Put the while loop op on host 0. + with ops.device(self.get_host_cpu_device(0)): + replicate_outputs = training_loop.repeat(iterations, rewrite_fn, + initial_loop_values) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) - # Filter out any ops from the outputs, typically this would be the case - # when there were no tensor outputs. 
- last_step_tensor_outputs = [x for x in replicate_outputs - if not isinstance(x, ops.Operation)] - - # Outputs are currently of the structure (grouped by device) - # [[output0_device0, output1_device0, output2_device0], - # [output0_device1, output1_device1, output2_device1]] - # Convert this to the following structure instead: (grouped by output) - # [[output0_device0, output0_device1], - # [output1_device0, output1_device1], - # [output2_device0, output2_device1]] - last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + if isinstance(replicate_outputs, list): + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [ + x for x in replicate_outputs if not isinstance(x, ops.Operation) + ] + + # Outputs are currently of the structure (flattened) + # [output0_device0, output1_device0, output2_device0, + # output0_device1, output1_device1, output2_device1, + # ...] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync + last_step_tensor_outputs = [ + last_step_tensor_outputs[i::output_num] for i in range(output_num) + ] + else: + # no tensors returned. + last_step_tensor_outputs = [] # Convert replicate_outputs to the original dict structure of # last_step_outputs. 
last_step_tensor_outputs_dict = nest.pack_sequence_as( ctx.last_step_outputs, last_step_tensor_outputs) - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, take the first value + # For outputs that have already been reduced, take the first value # from the list as each value should be the same. Else return the full # list of values. - # TODO(josh11b): If aggregation is NONE, we should return a PerReplica + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica # value. - if aggregation is not variables_lib.VariableAggregation.NONE: + if reduce_op is not None: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access @@ -320,33 +400,25 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def _call_for_each_replica(self, fn, args, kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. - with _TPUReplicaContext(self): + with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def initialize(self): + def _initialize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - # TODO(jhseu): We need this hack because DistributionStrategies must be - # pickleable for copy.deepcopy(). Remove when initialize_system goes away. 
- graph = ops.get_default_graph() - tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - if tpu_init: - return tpu_init - graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION, - tpu.initialize_system()) - return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - - def finalize(self): + return [] + + def _finalize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - return [tpu.shutdown_system()] + return [] def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. + # TODO(jhseu): Change this when we support model parallelism. return self._tpu_devices def _create_variable(self, next_creator, *args, **kwargs): @@ -383,12 +455,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, **kwargs) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self.num_replicas) - elif aggregation != vs.VariableAggregation.SUM: + value *= (1. / self._num_replicas_in_sync) + elif reduce_op != reduce_util.ReduceOp.SUM: raise NotImplementedError( "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) @@ -396,27 +468,22 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. 
- devices = cross_tower_ops_lib.get_devices_from(destinations) + devices = cross_device_ops_lib.get_devices_from(destinations) if len(devices) == 1: assert device_util.canonicalize(devices[0]) == device_util.canonicalize( self._host_device) else: raise ValueError("Multiple devices are not supported for TPUStrategy") - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return value[0] output = math_ops.add_n(value) - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: return output * (1. / len(value)) return output - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): assert isinstance(var, values.TPUMirroredVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if should_group: + if group: return fn(var, *args, **kwargs) else: return [fn(var, *args, **kwargs)] @@ -431,9 +498,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - # TODO(josh11b): Need to implement _update_non_slot()! + return values.update_regroup(self, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -442,25 +507,21 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. - return [val.get(device=d) for d in sorted(val.devices)] + return tuple(val.get(device=d) for d in sorted(val.devices)) elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should # be represented using a PerReplica wrapper instead of a list with # one entry per device. 
- return val - return [val] + return tuple(val) + return (val,) def value_container(self, value): return value - def _broadcast(self, tensor, destinations): + def _broadcast_to(self, tensor, destinations): del destinations return tensor - @property - def num_replicas(self): - return self._num_cores_override or self._tpu_metadata.num_cores - @property def num_hosts(self): return self._tpu_metadata.num_hosts @@ -470,15 +531,15 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return self._tpu_metadata.num_of_cores_per_host @property - def num_replicas_in_sync(self): - return self.num_replicas + def _num_replicas_in_sync(self): + return self._num_cores_override or self._tpu_metadata.num_cores @property - def between_graph(self): + def experimental_between_graph(self): return False @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -500,14 +561,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def non_slot_devices(self, var_list): return self._host_device - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
with ops.device(self._host_device), distribute_lib.UpdateContext( self._host_device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -521,17 +580,27 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def get_host_cpu_device(self, host_id): return self.get_host(host_id) + "/device:CPU:0" - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): del cluster_spec, task_type, task_id if session_config: - session_config.isolate_session_state = True - cluster_spec = self._tpu_cluster_resolver.cluster_spec() - if cluster_spec: - session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + session_config.CopyFrom(self._update_config_proto(session_config)) + + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + updated_config.isolate_session_state = True + cluster_spec = self._tpu_cluster_resolver.cluster_spec() + if cluster_spec: + updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + return updated_config + + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return True class _TPUReplicaContext(distribute_lib.ReplicaContext): @@ -540,13 +609,14 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): # TODO(sourabhbajaj): Call for each tower should be updating this. 
def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( - self, distribution_strategy, replica_id=0) - - @property - def device(self): - raise RuntimeError("Use .devices instead") + self, + distribution_strategy, + # TODO(b/118385803): properly initialize replica_id, instead of always 0 + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property def devices(self): distribute_lib.require_replica_context(self) - return [self._distribution_strategy.worker_devices[self._replica_id]] + ds = self._distribution_strategy + replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) + return (ds.extended.worker_devices[replica_id],) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 268393ee801b5f25bb5a7f061960b817c2d2ce5e..538b859f3d1ece55b460f6dbf8f01540a6013381 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -19,12 +19,15 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized -from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib @@ -34,10 +37,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from 
tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -324,20 +327,20 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) - self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) for device_id in range(3): d = _device_str(device_id) - self.assertEquals(created_estimator_specs[device_id].loss, - merged_estimator_spec.loss.get(d)) - self.assertEquals(created_estimator_specs[device_id].train_op, - merged_estimator_spec.train_op.get(d)) + self.assertEqual(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEqual(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) # Scaffold is populated by `EstimatorSpec.__new__`. 
- self.assertEquals(created_estimator_specs[device_id].scaffold, - merged_estimator_spec.scaffold.get(d)) + self.assertEqual(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) # Also test that we can undo the merge using select_device() - self.assertEquals(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) + self.assertEqual(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) class PerReplicaDatasetTest(test.TestCase): @@ -568,7 +571,184 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): multi_worker_iterator.get_next() -class MirroredVariableTest(test.TestCase): +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, + input_contexts) + else: + iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, + split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + 
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + 
self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -580,9 +760,9 @@ class MirroredVariableTest(test.TestCase): v, _, mirrored = _make_mirrored() - self.assertEquals(v[0].name, mirrored.name) - self.assertEquals(v[0].dtype, mirrored.dtype) - self.assertEquals(v[0].shape, mirrored.shape) + self.assertEqual(v[0].name, mirrored.name) + self.assertEqual(v[0].dtype, mirrored.dtype) + self.assertEqual(v[0].shape, mirrored.shape) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -592,9 +772,9 @@ class MirroredVariableTest(test.TestCase): mirrored = values.MirroredVariable(index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, mirrored.name) - self.assertEquals(v.dtype, mirrored.dtype) - self.assertEquals(v.shape, 
mirrored.shape) + self.assertEqual(v.name, mirrored.name) + self.assertEqual(v.dtype, mirrored.dtype) + self.assertEqual(v.shape, mirrored.shape) def _assign_mirrored(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -714,14 +894,13 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testFetchAMirroredVariable(self): - if context.num_gpus() < 1 or context.executing_eagerly(): - self.skipTest("A GPU is not available for this test or it's eager mode.") - - with self.session( - graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( - ["/device:GPU:0"]).scope(): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph"])) + def testFetchAMirroredVariable(self, distribution): + with self.session(graph=ops.Graph()) as sess, distribution.scope(): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) @@ -747,7 +926,7 @@ def _make_replica_local(method): return v, replica_local -class ReplicaLocalVariableTest(test.TestCase): +class ReplicaLocalVariablePropertiesTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -756,15 +935,14 @@ class ReplicaLocalVariableTest(test.TestCase): def testProperties(self): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM) - self.assertEquals(v[0].name, replica_local.name) - self.assertEquals(v[0].dtype, replica_local.dtype) - self.assertEquals(v[0].shape, replica_local.shape) - self.assertEquals(variable_scope.VariableAggregation.SUM, - replica_local.aggregation) + self.assertEqual(v[0].name, replica_local.name) + 
self.assertEqual(v[0].dtype, replica_local.dtype) + self.assertEqual(v[0].shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.SUM, + replica_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -774,11 +952,32 @@ class ReplicaLocalVariableTest(test.TestCase): replica_local = values.ReplicaLocalVariable( index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, replica_local.name) - self.assertEquals(v.dtype, replica_local.dtype) - self.assertEquals(v.shape, replica_local.shape) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - replica_local.aggregation) + self.assertEqual(v.name, replica_local.name) + self.assertEqual(v.dtype, replica_local.dtype) + self.assertEqual(v.shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + replica_local.aggregation) + + def testTensorConversion(self): + with context.graph_mode(): + _, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) + converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) + # Resources variable are converted to tensors as well when as_ref is True. 
+ self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _assign_replica_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -795,22 +994,15 @@ class ReplicaLocalVariableTest(test.TestCase): save_path, _ = self._save_return_saver(sess, var) return save_path - def _dist_scope(self): - return mirrored_strategy.MirroredStrategy(_devices).scope() - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreReplicaLocalSumOneGraph(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - with self.cached_session(config=self.config) as sess: + def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): + with self.cached_session() as sess: v, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 7. 
save_path, saver = self._save_return_saver(sess, replica_local) @@ -822,19 +1014,18 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreReplicaLocalMeanOneGraph(self): + def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.cached_session(config=self.config) as sess: + with self.cached_session() as sess: v, replica_local = _make_replica_local( variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5. save_path, saver = self._save_return_saver(sess, replica_local) @@ -845,7 +1036,7 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _save_replica_local_mean(self): + def _save_replica_local_mean(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -854,7 +1045,7 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5 save_path = self._save(sess, replica_local) @@ -862,7 +1053,7 @@ class ReplicaLocalVariableTest(test.TestCase): self._assign_replica_local(_devices, v, [5., 6.]) return save_path - def _save_replica_local_sum(self): + def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local("sum") @@ -870,7 +1061,7 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 3.5 save_path = self._save(sess, replica_local) @@ -908,7 +1099,7 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual(3.5, self.evaluate(var)) - def _restore_replica_local_mean(self, save_path): + def _restore_replica_local_mean(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -917,13 +1108,13 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. 
saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _restore_replica_local_sum(self, save_path): + def _restore_replica_local_sum(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -932,72 +1123,35 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalRestoreReplicaLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution): + save_path = self._save_replica_local_mean(distribution) + self._restore_replica_local_mean(save_path, distribution) - save_path = self._save_replica_local_mean() - self._restore_replica_local_mean(save_path) + def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution): + save_path = self._save_replica_local_sum(distribution) + self._restore_replica_local_sum(save_path, distribution) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalRestoreReplicaLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_sum() - self._restore_replica_local_sum(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalMeanRestoreNormal(self): - if context.num_gpus() < 1 and 
context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_mean() + def testSaveReplicaLocalMeanRestoreNormal(self, distribution): + save_path = self._save_replica_local_mean(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalSumRestoreNormal(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_sum() + def testSaveReplicaLocalSumRestoreNormal(self, distribution): + save_path = self._save_replica_local_sum(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreReplicaLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - + def testSaveNormalRestoreReplicaLocalMean(self, distribution): save_path = self._save_normal() - self._restore_replica_local_mean(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreReplicaLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + self._restore_replica_local_mean(save_path, distribution) + def testSaveNormalRestoreReplicaLocalSum(self, distribution): save_path = self._save_normal() - self._restore_replica_local_sum(save_path) - - def testTensorConversion(self): - with context.graph_mode(): - _, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) - converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, replica_local.dtype) - - converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) - # Resources variable are 
converted to tensors as well when as_ref is True. - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, replica_local.dtype) + self._restore_replica_local_sum(save_path, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py index 5d57d144c1c16a08280970ecd89eb54f7cf1ffd4..b0bcf9b17456c938204a4892451928daf90b6743 100644 --- a/tensorflow/contrib/distribute/python/warm_starting_util_test.py +++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py @@ -44,7 +44,9 @@ class WarmStartingUtilWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], save_with_distribution=[True, False], restore_with_distribution=[True, False], mode=["graph"])) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 60f6b90edcb71f04bca29b90744db201e83cd545..3079175015a9aee1625404902070df8f13b2089c 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -72,7 +72,6 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", - "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -80,6 +79,7 @@ py_library( "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", + "//tensorflow/python/ops/signal", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py 
b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py index 29eeaf43c5185ce5519d4a1211f66e99ce61c6ab..ab3c07172a68255f4e387e071ac2f8341e93b90c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py @@ -82,7 +82,7 @@ class NormalTest(test.TestCase): x = constant_op.constant( [[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], [2.5, -2.5, -4.0, 0.0, 1.0, -2.0]], dtype=dtypes.float32) - s = math_ops.reduce_sum(x, reduction_indices=[1]) + s = math_ops.reduce_sum(x, axis=[1]) x = array_ops.transpose(x) # Reshape to shape (6, 2) n = constant_op.constant([6] * 2) prior = distributions.Normal(loc=mu0, scale=sigma0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index a60056c444a3fe7262939c5b3c73673f9a7c1469..cdee30bbc42e661952a9c757d7a30ebcd393f794 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -147,14 +147,13 @@ class WishartCholeskyTest(test.TestCase): x = chol_w.sample(10000, seed=42) self.assertAllEqual((10000, 3, 3), x.get_shape()) - moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval() + moment1_estimate = math_ops.reduce_mean(x, axis=[0]).eval() self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05) # The Variance estimate uses the squares rather than outer-products # because Wishart.Variance is the diagonal of the Wishart covariance # matrix. 
- variance_estimate = (math_ops.reduce_mean( - math_ops.square(x), reduction_indices=[0]) - + variance_estimate = (math_ops.reduce_mean(math_ops.square(x), axis=[0]) - math_ops.square(moment1_estimate)).eval() self.assertAllClose( chol_w.variance().eval(), variance_estimate, rtol=0.05) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index 15c241d5d7a29d0e317cb6e5f46a40516e8a834f..74765f19e584c5de07c6aee4a36ec4e85438f862 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -168,7 +168,7 @@ class SoftmaxCentered(bijector.Bijector): # log_normalization = 1 + reduce_sum(exp(logits)) # -log_normalization + reduce_sum(logits - log_normalization) log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + math_ops.reduce_logsumexp(x, axis=-1, keepdims=True)) return array_ops.squeeze( (-log_normalization + math_ops.reduce_sum( x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index aa680a92be64cf0f099acd335369f2a1610c5953..978e627d6638ddeea9df288d389354f0ac53d115 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -29,8 +29,8 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import spectral_ops from tensorflow.python.ops.distributions import util +from tensorflow.python.ops.signal import fft_ops __all__ = [ "auto_correlation", @@ -157,11 +157,11 @@ def auto_correlation( dtype.real_dtype.as_numpy_dtype(0.)) # Autocorrelation 
is IFFT of power-spectral density (up to some scaling). - fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + fft_x_rotated_pad = fft_ops.fft(x_rotated_pad) spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) # shifted_product is R[m] from above detailed explanation. # It is the inner product sum_n X[n] * Conj(X[n - m]). - shifted_product = spectral_ops.ifft(spectral_density) + shifted_product = fft_ops.ifft(spectral_density) # Cast back to real-valued if x was real to begin with. shifted_product = math_ops.cast(shifted_product, dtype) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 3aed121233be1268531495a2fa83fd323412e1fd..34614b86a75b93ab93cf844c645c211b1329c6d5 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -52,12 +52,6 @@ class Iterator(iterator_ops.EagerIterator): TypeError: If `dataset` is an unsupported type. RuntimeError: When invoked without eager execution enabled. """ - if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access - raise TypeError( - "`tf.data.experimental.prefetch_to_device()` is not compatible with " - "`tf.contrib.eager.Iterator`. Use `for ... 
in dataset:` to iterate " - "over the dataset instead.") - if not context.context().device_spec.device_type: is_remote_device = False else: diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 6a508fc6ba98740c4d441a064dc8a3e2b321f585..257d02057ae0d280074559aa9e97725bf5cc3fd0 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -26,7 +26,6 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset -from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.experimental.ops import unique from tensorflow.python.eager import test @@ -208,18 +207,6 @@ class IteratorTest(test.TestCase): y = math_ops.add(x, x) self.assertAllEqual([0., 2.], y.numpy()) - def testTensorsExplicitPrefetchToDevice(self): - ds = Dataset.from_tensor_slices([0., 1.]) - ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) - - with self.assertRaisesRegexp(TypeError, 'prefetch_to_device'): - datasets.Iterator(ds) - - for i, x in enumerate(ds): - with ops.device(test.gpu_device_name()): - x = math_ops.add(x, x) - self.assertEqual(float(i) + float(i), x.numpy()) - def testOverrideThreadPool(self): def get_thread_id(_): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 7949a3f6da293abdd85512209242bae76ab4d816..51443d24829bdc31a41813e0ff50ad7102422112 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,6 +22,7 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from 
tensorflow.python.eager import function from tensorflow.python.framework import errors_impl @@ -164,8 +165,8 @@ class Evaluator(object): self.__call__(example, *args, **kwargs) return self.all_metric_results(summary_logdir) # Graph construction - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, - **kwargs) + call_op = self.__call__( + dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs) init_op = self.init_variables() results_op = self.all_metric_results(summary_logdir) return (init_op, call_op, results_op) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index 2dc196f550a10367066730f6f042c4ed69533ec3..e2154fcc5fcf774dcd52285d9442dfd5073a4992 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "densenet", diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index 4b3cb624bc947a1d1956eff6accb6d4da3bf3b87..24f6b007b526b29157011f3b1e9abdbd50bacc8e 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -119,7 +119,8 @@ class DensenetBenchmark(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + (images, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, 
self.output_classes, diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py index 12b39b0cde49d4c017acfa74572c725036c54eff..e73841fbf724e05eaa3be90cc8650f795d3e1ccf 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -42,7 +42,8 @@ class MnistGraphGanBenchmark(tf.test.Benchmark): # Generate some random data. images_data = np.random.randn(batch_size, 784).astype(np.float32) dataset = tf.data.Dataset.from_tensors(images_data) - images = dataset.repeat().make_one_shot_iterator().get_next() + images = tf.compat.v1.data.make_one_shot_iterator( + dataset.repeat()).get_next() # Create the models and optimizers generator = mnist.Generator(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index ca27a85a229d41a85fa26ecdc982da478fe9e202..1a08cc0fd06516be4af5c2b0b46a3ffcf9101e95 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -470,7 +470,7 @@ "\n", " if epoch % 1 == 0:\n", " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset.make_one_shot_iterator():\n", + " for test_x in test_dataset:\n", " loss(compute_loss(model, test_x))\n", " elbo = -loss.result()\n", " display.clear_output(wait=False)\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 3acecd283cda83992bab0c37cf0b8037ed2cf27a..12c5eff2b4aa901bdab52bf545e95b1e4dce7468 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ 
b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1184 +1,1174 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "Image captioning is the task of generating a caption for an image. Given an image like this:\n", + "\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", + "\n", + "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", + "\n", + "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", + "\n", + "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", + "\n", + "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", + "\n", + "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "name": "image_captioning_with_attention.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 - } - ], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "accelerator": "GPU" + "colab_type": "code", + "id": "U8l4RJ0XRPEm" + }, + "outputs": [], + "source": [ + "# Import TensorFlow and enable eager execution\n", + "# This code requires TensorFlow version >=1.9\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "# We'll generate plots of attention in order to see which parts of an image\n", + "# our model focuses on during captioning\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Scikit-learn includes many helpful utilities\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.utils import shuffle\n", + "\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "import json\n", + "from glob import glob\n", + "from PIL import Image\n", + "import pickle" + ] }, - "cells": [ - { - "metadata": { - "id": "K2s1A9eLRPEj", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] - }, - { - "metadata": { - "id": "Cffg2i257iMS", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "QASbY_HGo4Lq", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "metadata": { - "id": "U8l4RJ0XRPEm", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "b6qbGw8MRPE5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
- ] - }, - { - "metadata": { - "id": "krQuPYTtRPE7", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "aANEzb5WwSzg", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
- ] - }, - { - "metadata": { - "id": "4G3b8x8_RPFD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "mPBMgK34RPFL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(train_captions), len(all_captions)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "8cSW4u-ORPFQ", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. 
\n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "metadata": { - "id": "zXR0217aRPFR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "MDvIu4sXRPFV", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." 
- ] - }, - { - "metadata": { - "id": "RD3vW4SsRPFW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "rERqlR3WRPGO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] - }, - { - "metadata": { - "id": "Dx_fvbVgRPGQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "nyqH3zFwRPFi", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. 
" - ] - }, - { - "metadata": { - "id": "HZfK8RhQRPFj", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "oJGE34aiRPFo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b6qbGw8MRPE5" + }, + "source": [ + "## Download and prepare the MS-COCO dataset\n", + "\n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", + "\n", + "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "8Q44tNQVRPFt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", - "# putting token in the word2idx dictionary\n", - "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", - "tokenizer.word_index[''] = 0" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "krQuPYTtRPE7" + }, + "outputs": [], + "source": [ + "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", + " extract = True)\n", + "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", + "\n", + "name_of_zip = 'train2014.zip'\n", + "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", + " image_zip = tf.keras.utils.get_file(name_of_zip, \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", + " extract = True)\n", + " PATH = os.path.dirname(image_zip)+'/train2014/'\n", + "else:\n", + " PATH = os.path.abspath('.')+'/train2014/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aANEzb5WwSzg" + }, + "source": [ + "## Optionally, limit the size of the training set for faster training\n", + "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "0fpJb5ojRPFv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "4G3b8x8_RPFD" + }, + "outputs": [], + "source": [ + "# read the json file\n", + "with open(annotation_file, 'r') as f:\n", + " annotations = json.load(f)\n", + "\n", + "# storing the captions and the image name in vectors\n", + "all_captions = []\n", + "all_img_name_vector = []\n", + "\n", + "for annot in annotations['annotations']:\n", + " caption = ' ' + annot['caption'] + ' '\n", + " image_id = annot['image_id']\n", + " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", + " \n", + " all_img_name_vector.append(full_coco_image_path)\n", + " all_captions.append(caption)\n", + "\n", + "# shuffling the captions and image_names together\n", + "# setting a random state\n", + "train_captions, img_name_vector = shuffle(all_captions,\n", + " all_img_name_vector,\n", + " random_state=1)\n", + "\n", + "# selecting the first 30000 captions from the shuffled set\n", + "num_examples = 30000\n", + "train_captions = train_captions[:num_examples]\n", + "img_name_vector = img_name_vector[:num_examples]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "olQArbgbRPF1", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating a reverse mapping (index -> word)\n", - "index_word = {value:key for key, value in 
tokenizer.word_index.items()}" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "mPBMgK34RPFL" + }, + "outputs": [], + "source": [ + "len(train_captions), len(all_captions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8cSW4u-ORPFQ" + }, + "source": [ + "## Preprocess the images using InceptionV3\n", + "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", + "\n", + "First, we will need to convert the images into the format inceptionV3 expects by:\n", + "* Resizing the image to (299, 299)\n", + "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AidglIZVRPF4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "zXR0217aRPFR" + }, + "outputs": [], + "source": [ + "def load_image(image_path):\n", + " img = tf.read_file(image_path)\n", + " img = tf.image.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize_images(img, (299, 299))\n", + " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", + " return img, image_path" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": 
"MDvIu4sXRPFV" + }, + "source": [ + "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", + "\n", + "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", + "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", + "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", + "* We avoid doing this during training so it does not become a bottleneck. \n", + "* After all the images are passed through the network, we pickle the dictionary and save it to disk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "gL0wkttkRPGA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RD3vW4SsRPFW" + }, + "outputs": [], + "source": [ + "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", + " weights='imagenet')\n", + "new_input = image_model.input\n", + "hidden_layer = image_model.layers[-1].output\n", + "\n", + "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rERqlR3WRPGO" + }, + "source": [ + "## Caching the features extracted from InceptionV3\n", + "\n", + "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. 
At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", + "\n", + "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", + "\n", + "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", + "\n", + "```for img, path in image_dataset:``` \n", + "\n", + "to:\n", + "\n", + "```for img, path in tqdm(image_dataset):```." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "M3CD75nDpvTI", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Split the data into training and testing" - ] + "colab_type": "code", + "id": "Dx_fvbVgRPGQ" + }, + "outputs": [], + "source": [ + "# getting the unique images\n", + "encode_train = sorted(set(img_name_vector))\n", + "\n", + "# feel free to change the batch_size according to your system configuration\n", + "image_dataset = tf.data.Dataset.from_tensor_slices(\n", + " encode_train).map(load_image).batch(16)\n", + "\n", + "for img, path in image_dataset:\n", + " batch_features = image_features_extract_model(img)\n", + " batch_features = tf.reshape(batch_features, \n", + " (batch_features.shape[0], -1, batch_features.shape[3]))\n", + "\n", + " for bf, p in zip(batch_features, path):\n", + " path_of_feature = p.numpy().decode(\"utf-8\")\n", + " np.save(path_of_feature, bf.numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nyqH3zFwRPFi" + }, + "source": [ + "## Preprocess and tokenize the captions\n", + "\n", + "* First, we'll tokenize the captions (e.g., by splitting on 
spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", + "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", + "* Finally, we create a word --> index mapping and vice-versa.\n", + "* We will then pad all sequences to the be same length as the longest one. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "iS7DDMszRPGF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "HZfK8RhQRPFj" + }, + "outputs": [], + "source": [ + "# This will find the maximum length of any caption in our dataset\n", + "def calc_max_length(tensor):\n", + " return max(len(t) for t in tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "XmViPkRFRPGH", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "oJGE34aiRPFo" + }, + "outputs": [], + "source": [ + "# The steps above is a general process of dealing with text processing\n", + "\n", + "# choosing the top 5000 words from the vocabulary\n", + "top_k = 5000\n", + "tokenizer = 
tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", + " oov_token=\"\", \n", + " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", + "tokenizer.fit_on_texts(train_captions)\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "uEWM9xrYcg45", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] + "colab_type": "code", + "id": "8Q44tNQVRPFt" + }, + "outputs": [], + "source": [ + "tokenizer.word_index[''] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Q3TnZ1ToRPGV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "0fpJb5ojRPFv" + }, + "outputs": [], + "source": [ + "# creating the tokenized vectors\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "SmZS2N0bXG3T", - "colab_type": "code", - "colab": { - "autoexec": { - 
"startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AidglIZVRPF4" + }, + "outputs": [], + "source": [ + "# padding each vector to the max_length of the captions\n", + "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", + "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "FDF_Nm3tRPGZ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset = dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "gL0wkttkRPGA" + }, + "outputs": [], + "source": [ + "# calculating the max_length \n", + "# used to store the attention weights\n", + "max_length = calc_max_length(train_seqs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", 
+ "id": "M3CD75nDpvTI" + }, + "source": [ + "## Split the data into training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "nrvoDphgRPGd", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." 
- ] + "colab_type": "code", + "id": "iS7DDMszRPGF" + }, + "outputs": [], + "source": [ + "# Create training and validation sets using 80-20 split\n", + "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", + " cap_vector, \n", + " test_size=0.2, \n", + " random_state=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AAppCGLKRPGd", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "XmViPkRFRPGH" + }, + "outputs": [], + "source": [ + "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uEWM9xrYcg45" + }, + "source": [ + "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "ja2LFTMSdeV3", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Q3TnZ1ToRPGV" + }, + "outputs": [], + "source": [ + "# feel free to change these parameters according to your system's configuration\n", + "\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 1000\n", + "embedding_dim = 256\n", + "units = 512\n", + "vocab_size = len(tokenizer.word_index)\n", + "# shape of the 
vector extracted from InceptionV3 is (64, 2048)\n", + "# these two variables represent that\n", + "features_shape = 2048\n", + "attention_features_shape = 64" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AZ7R1RxHRPGf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "SmZS2N0bXG3T" + }, + "outputs": [], + "source": [ + "# loading the numpy files \n", + "def map_func(img_name, cap):\n", + " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", + " return img_tensor, cap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "V9UbGQmERPGi", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 
= tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "FDF_Nm3tRPGZ" + }, + "outputs": [], + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", + "\n", + "# using map to load the numpy files in parallel\n", + "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", + "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", + "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", + " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", + "\n", + "# shuffling and batching\n", + "dataset = dataset.shuffle(BUFFER_SIZE)\n", + "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", + "dataset = dataset.batch(BATCH_SIZE)\n", + "dataset = dataset.prefetch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"colab_type": "text", + "id": "nrvoDphgRPGd" + }, + "source": [ + "## Model\n", + "\n", + "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", + "\n", + "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", + "* We squash that to a shape of (64, 2048).\n", + "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", + "* The RNN(here GRU) attends over the image to predict the next word." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Qs_Sr03wRPGk", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AAppCGLKRPGd" + }, + "outputs": [], + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", + " # significant speedup).\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "-bYN7xA0RPGl", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "ja2LFTMSdeV3" + }, + "outputs": [], + "source": [ + "class BahdanauAttention(tf.keras.Model):\n", + " def __init__(self, units):\n", + " super(BahdanauAttention, self).__init__()\n", + " self.W1 = tf.keras.layers.Dense(units)\n", + " self.W2 = tf.keras.layers.Dense(units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, features, hidden):\n", + " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", + " \n", + " # hidden shape == (batch_size, hidden_size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, 64, hidden_size)\n", + " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, 64, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * features\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " return context_vector, attention_weights" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "PHod7t72RPGn", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] + "colab_type": "code", + "id": "AZ7R1RxHRPGf" + }, + "outputs": [], + "source": [ + "class CNN_Encoder(tf.keras.Model):\n", + " # Since we have already extracted the features and dumped it using pickle\n", + " # This encoder passes those features through a Fully connected layer\n", + " def __init__(self, embedding_dim):\n", + " super(CNN_Encoder, self).__init__()\n", + " # shape after fc == (batch_size, 64, embedding_dim)\n", + " self.fc = tf.keras.layers.Dense(embedding_dim)\n", + " \n", + " def call(self, x):\n", + " x = self.fc(x)\n", + " x = tf.nn.relu(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Vt4WZ5mhJE-E", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# adding this in a separate cell because if you run the training cell 
\n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "V9UbGQmERPGi" + }, + "outputs": [], + "source": [ + "class RNN_Decoder(tf.keras.Model):\n", + " def __init__(self, embedding_dim, units, vocab_size):\n", + " super(RNN_Decoder, self).__init__()\n", + " self.units = units\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.units)\n", + " self.fc1 = tf.keras.layers.Dense(self.units)\n", + " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " self.attention = BahdanauAttention(self.units)\n", + " \n", + " def call(self, x, features, hidden):\n", + " # defining attention as a separate model\n", + " context_vector, attention_weights = self.attention(features, hidden)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # shape == (batch_size, max_length, hidden_size)\n", + " x = self.fc1(output)\n", + " \n", + " # x shape == (batch_size * max_length, hidden_size)\n", + " x = tf.reshape(x, (-1, x.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc2(x)\n", + "\n", + " return x, state, attention_weights\n", + "\n", + " def reset_state(self, batch_size):\n", + " return tf.zeros((batch_size, self.units))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "UlA4VIQpRPGo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - 
"cell_type": "code", - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Qs_Sr03wRPGk" + }, + "outputs": [], + "source": [ + "encoder = CNN_Encoder(embedding_dim)\n", + "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "1Wm83G-ZBPcC", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "-bYN7xA0RPGl" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# We are masking the loss calculated for padding\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PHod7t72RPGn" + }, + "source": [ + "## Training\n", + "\n", + "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", + "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", + "* The decoder returns the predictions and the decoder hidden state.\n", + "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "* Use teacher forcing to decide the next input to the decoder.\n", + "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", + "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "xGvOcLQKghXN", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - 
"## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] + "colab_type": "code", + "id": "Vt4WZ5mhJE-E" + }, + "outputs": [], + "source": [ + "# adding this in a separate cell because if you run the training cell \n", + "# many times, the loss_plot array will be reset\n", + "loss_plot = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "RCWpDtyNRPGs", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(index_word[predicted_id])\n", - "\n", - " if index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " 
attention_plot = attention_plot[:len(result), :]\n", - " return result, attention_plot" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "UlA4VIQpRPGo" + }, + "outputs": [], + "source": [ + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " total_loss = 0\n", + " \n", + " for (batch, (img_tensor, target)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " # initializing the hidden state for each batch\n", + " # because the captions are not related from image to image\n", + " hidden = decoder.reset_state(batch_size=target.shape[0])\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", + " \n", + " with tf.GradientTape() as tape:\n", + " features = encoder(img_tensor)\n", + " \n", + " for i in range(1, target.shape[1]):\n", + " # passing the features through the decoder\n", + " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", + "\n", + " loss += loss_function(target[:, i], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(target[:, i], 1)\n", + " \n", + " total_loss += (loss / int(target.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables) \n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + " \n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", + " batch, \n", + " loss.numpy() / int(target.shape[1])))\n", + " # storing the epoch end loss value to plot later\n", + " loss_plot.append(total_loss / len(cap_vector))\n", + " \n", + " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", + " total_loss/len(cap_vector)))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": 
false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "fD_y7PD6RPGt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "1Wm83G-ZBPcC" + }, + "outputs": [], + "source": [ + "plt.plot(loss_plot)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.title('Loss Plot')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xGvOcLQKghXN" + }, + "source": [ + "## Caption!\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the end token.\n", + "* And store the attention weights for every time step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "io7ws3ReRPGv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RCWpDtyNRPGs" + }, + "outputs": [], + "source": [ + "def evaluate(image):\n", + " attention_plot = np.zeros((max_length, attention_features_shape))\n", + "\n", + " hidden = decoder.reset_state(batch_size=1)\n", + "\n", + " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", + " img_tensor_val = image_features_extract_model(temp_input)\n", + " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", + "\n", + " features = encoder(img_tensor_val)\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", + " result = []\n", + "\n", + " for i in range(max_length):\n", + " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", + "\n", + " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", + "\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", + " result.append(tokenizer.index_word[predicted_id])\n", + "\n", + " if tokenizer.index_word[predicted_id] == '':\n", + " return result, attention_plot\n", + "\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", 
+ "\n", + " attention_plot = attention_plot[:len(result), :]\n", + " return result, attention_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Rprk3HEvZuxb", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Try it on your own images\n", - "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] + "colab_type": "code", + "id": "fD_y7PD6RPGt" + }, + "outputs": [], + "source": [ + "def plot_attention(image, result, attention_plot):\n", + " temp_image = np.array(Image.open(image))\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " \n", + " len_result = len(result)\n", + " for l in range(len_result):\n", + " temp_att = np.resize(attention_plot[l], (8, 8))\n", + " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", + " ax.set_title(result[l])\n", + " img = ax.imshow(temp_image)\n", + " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "9Psd1quzaAWg", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - 
"plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "io7ws3ReRPGv" + }, + "outputs": [], + "source": [ + "# captions on the validation set\n", + "rid = np.random.randint(0, len(img_name_val))\n", + "image = img_name_val[rid]\n", + "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", + "result, attention_plot = evaluate(image)\n", + "\n", + "print ('Real Caption:', real_caption)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image, result, attention_plot)\n", + "# opening the image\n", + "Image.open(img_name_val[rid])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rprk3HEvZuxb" + }, + "source": [ + "## Try it on your own images\n", + "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, + "colab_type": "code", + "id": "9Psd1quzaAWg" + }, + "outputs": [], + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_extension = image_url[-4:]\n", + "image_path = tf.keras.utils.get_file('image'+image_extension, \n", + " origin=image_url)\n", + "\n", + "result, attention_plot = evaluate(image_path)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image_path, result, attention_plot)\n", + "# opening the image\n", + "Image.open(image_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VJZXyJco6uLO" + }, + "source": [ + "# Next steps\n", + "\n", + "Congrats! 
You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ { - "metadata": { - "id": "VJZXyJco6uLO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Next steps\n", - "\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." 
- ] + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 } - ] + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py index 557ad42752144243ae3da61b955b31398cba846e..d412b25b368260b81256fd58034330b884261b2b 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -36,7 +36,7 @@ class GraphLinearRegressionBenchmark(tf.test.Benchmark): noise_level=0.01, batch_size=batch_size, num_batches=num_batches) - iterator = dataset.make_initializable_iterator() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) x, y = iterator.get_next() model = linear_regression.LinearModel() diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 480777d948769b56ac1cc3be2052fe48459e98d6..66d52a74943d0d81fde05ce51b019558b327978d 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -768,7 +768,7 @@ }, "outputs": [], "source": [ - "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + 
"translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -781,7 +781,7 @@ }, "outputs": [], "source": [ - "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -794,7 +794,7 @@ }, "outputs": [], "source": [ - "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -808,7 +808,7 @@ "outputs": [], "source": [ "# wrong translation\n", - "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index f3bb978875e226f58d6a00e09154191673a97415..fb7975d8fe867711cff31d627788a2d62a520aa9 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -142,7 +142,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + images, labels = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = resnet50.ResNet50(data_format()) logits = model(images, training=True) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index 
b702e91f92220c2a9003a1b82411131332012a9e..9585f3565f83af724b6336e466d3671443ba2361 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -72,14 +72,11 @@ def main(_): train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: - it_test = ds_test.make_one_shot_iterator() - acc_test, loss_test = evaluate(model, it_test) + acc_test, loss_test = evaluate(model, ds_test) if FLAGS.validate: - it_train = ds_train_one_shot.make_one_shot_iterator() - it_validation = ds_validation.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) - acc_validation, loss_validation = evaluate(model, it_validation) + acc_train, loss_train = evaluate(model, ds_train_one_shot) + acc_validation, loss_validation = evaluate(model, ds_validation) print("Iter {}, " "training set accuracy {:.4f}, loss {:.4f}; " "validation set accuracy {:.4f}, loss {:.4f}; " @@ -218,11 +215,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None): return logits, loss -def evaluate(model, iterator): +def evaluate(model, dataset): """Compute accuracy with the given dataset iterator.""" mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in iterator: + for x, y in dataset: logits, _ = model(x, training=False) loss = model.compute_loss(logits=logits, labels=y) accuracy( diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py index 63b5c4c54d13e9c2448ec1f572ca1389f2443bef..770484abed96e540cf75cc5368a1410c31a8d2d0 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py @@ -82,7 +82,7 @@ class PTBBenchmark(tf.test.Benchmark): tf.ones( [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)).repeat(num_iters + 
num_warmup) - inputs = dataset.make_one_shot_iterator().get_next() + inputs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() with tf.device(tf.test.gpu_device_name()): outputs = model(inputs, training=True) @@ -124,7 +124,8 @@ class PTBBenchmark(tf.test.Benchmark): dtype=tf.int64)).repeat(num_iters + num_warmup) # inputs and labels have the same shape dataset = tf.data.Dataset.zip((dataset, dataset)) - (inputs, labels) = dataset.make_one_shot_iterator().get_next() + (inputs, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() with tf.device(tf.test.gpu_device_name()): optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c88c0f52eead58c7562cda1a49d164c1d857822d..566246de4957c1dc5919c10e22146706f9e50be8 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -354,9 +355,10 @@ class Mean(Metric): def write_summary_f(): summary_ops.scalar(name=self.name, tensor=t) return t - control_flow_ops.cond(write_summary, + smart_cond.smart_cond(write_summary, write_summary_f, - lambda: t) + lambda: t, + name="") return t diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 9d2d172752c7f3f3ee6eaa11ab8952313a4a3543..39e5957f5d1760613f2c33607c0bdb163040efb4 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -49,18 +49,6 @@ class 
MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) - def testSummaryArg(self): - m = metrics.Mean() - m([1, 10, 100]) - m(1000) - m([10000.0, 100000.0]) - self.assertEqual(111111.0/6, m.result(write_summary=True).numpy()) - self.assertEqual(111111.0/6, m.result(write_summary=False).numpy()) - with self.assertRaises(ValueError): - m.result(write_summary=5) - with self.assertRaises(ValueError): - m.result(write_summary=[True]) - def testVariableCollections(self): with context.graph_mode(), ops.Graph().as_default(): m = metrics.Mean() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index f801d9a47b2f831a48d9b6335c69612c1356d800..5cc0c4f23d9d641ff1452c7cc9c1fcde612a33a2 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -24,7 +24,7 @@ import weakref from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -220,7 +220,7 @@ class Network(base.Layer): avoid_names = parent_network._owned_layers name_uid_map = parent_network._sub_layer_name_uids else: - name_uid_map = keras_base_layer.get_default_graph_uid_map() + name_uid_map = base_layer_utils.get_default_graph_uid_map() # Figure out which names we have to avoid based on which variable scope # we're nested in. 
strip_name = self._default_parent_variable_scope.name diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index f9c716360c5755ee1902b576545d776725f9966f..1d0d6c6c14ce4a8e454206e0be9fea4724f09192 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -115,6 +115,11 @@ def restore_variables_on_create(save_path, map_func=None): class Saver(object): """A tf.train.Saver adapter for use when eager execution is enabled. + + `Saver`'s name-based checkpointing strategy is fragile. Please switch to + `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more + robust object-based saving. These APIs will load checkpoints written by + `Saver`. """ def __init__(self, var_list): diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 4454abfb9667f824b9de0100bb81bae24ad5f7a6..8c35dddb5a515aa09cc70c173a9f0605e8567e82 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -87,8 +87,8 @@ class TFETest(test_util.TensorFlowTestCase): x += 1. # Without a device context, heuristics are used to place ops. # In this case, ops.reduce_mean runs on the GPU. - reduction_indices = range(x.shape.ndims) - m = math_ops.reduce_mean(x, reduction_indices) + axis = range(x.shape.ndims) + m = math_ops.reduce_mean(x, axis) # m is on GPU, bring it back to CPU and compare. 
self.assertEqual(3.5, m.cpu().numpy()) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 37f253d9c115ca4a6d3c30aca33ca1be12b4a927..a888379f13e79d1c246d4cd6d19a225c065692a2 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -16,7 +16,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":boosted_trees", - ":dnn", ":dnn_with_layer_annotations", ":early_stopping", ":expect_tensorflow_estimator_installed", @@ -25,7 +24,6 @@ py_library( ":extenders", ":head", ":hooks", - ":linear", ":logit_fns", ":multi_head", ":replicate_model_fn", @@ -47,18 +45,6 @@ py_library( ], ) -py_library( - name = "dnn", - srcs = ["python/estimator/dnn.py"], - srcs_version = "PY2AND3", - deps = [ - ":expect_tensorflow_estimator_installed", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:dnn", - ], -) - py_library( name = "dnn_with_layer_annotations", srcs = ["python/estimator/dnn_with_layer_annotations.py"], @@ -144,17 +130,6 @@ py_library( ], ) -py_library( - name = "linear", - srcs = ["python/estimator/linear.py"], - srcs_version = "PY2AND3", - deps = [ - ":expect_tensorflow_estimator_installed", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:linear", - ], -) - py_library( name = "logit_fns", srcs = [ diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 80d59627620b86b5ebc20e1631ca368a0f2f6fdf..7d61247e7ef26d3777843cd3be20684583e9058c 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -58,8 +58,6 @@ _allowed_symbols = [ 'multi_label_head', 'poisson_regression_head', 'regression_head', - 'DNNEstimator', - 'LinearEstimator', 'boosted_trees_classifier_train_in_memory', 'boosted_trees_regressor_train_in_memory', 'call_logit_fn', diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py 
b/tensorflow/contrib/factorization/python/ops/kmeans.py index f384d761a8430074f022c973d7ec3d46cd90f70b..3eb396a29ccdc0478384f9fa122465731740a30d 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,7 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output -from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 1ab5418fe4659cb0068ee8c3ca1442f6f723ee76..2f7cd131d3ed20df307ed231cce2ecb50ecfbceb 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,7 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config -from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index bbe335be3e1384e8a86872165a4e37230e28b6c9..1cd83bdb5de7c2f6dc91c980750b49aca1a7790b 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -14,6 +14,7 @@ 
py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", + ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -32,7 +33,7 @@ py_library( "//tensorflow/python:sparse_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", ], ) @@ -51,7 +52,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -69,7 +70,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], ) @@ -89,7 +90,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/feature_column:feature_column_py", ], ) @@ -110,7 +111,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index dd6da35ed009c07ad3819e7860a283c7837c1f83..9b3a5c58aaa9498257fc971ac60b97f31d5185d8 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -222,10 +222,8 
@@ def sequence_categorical_column_with_identity( ValueError: if `default_value` is not in range `[0, num_buckets)`. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_identity( - key=key, - num_buckets=num_buckets, - default_value=default_value)) + fc._categorical_column_with_identity( + key=key, num_buckets=num_buckets, default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -265,10 +263,8 @@ def sequence_categorical_column_with_hash_bucket( ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_hash_bucket( - key=key, - hash_bucket_size=hash_bucket_size, - dtype=dtype)) + fc._categorical_column_with_hash_bucket( + key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -324,7 +320,7 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_vocabulary_file( + fc._categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -384,7 +380,7 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: if `dtype` is not integer or string. 
""" return fc._SequenceCategoricalColumn( - fc.categorical_column_with_vocabulary_list( + fc._categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py index d8ca363627eace15e039679545366648df174c33..bcc25b8de895a769f9e11b207c2092e23d029b1f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -53,19 +53,20 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase): return example def _build_feature_columns(self): - col = fc.categorical_column_with_identity( - 'int_ctx', num_buckets=100) + col = fc._categorical_column_with_identity('int_ctx', num_buckets=100) ctx_cols = [ - fc.embedding_column(col, dimension=10), - fc.numeric_column('float_ctx')] + fc._embedding_column(col, dimension=10), + fc._numeric_column('float_ctx') + ] identity_col = sfc.sequence_categorical_column_with_identity( 'int_list', num_buckets=10) bucket_col = sfc.sequence_categorical_column_with_hash_bucket( 'bytes_list', hash_bucket_size=100) seq_cols = [ - fc.embedding_column(identity_col, dimension=10), - fc.embedding_column(bucket_col, dimension=20)] + fc._embedding_column(identity_col, dimension=10), + fc._embedding_column(bucket_col, dimension=20) + ] return ctx_cols, seq_cols @@ -148,8 +149,8 @@ class SequenceExampleParsingTest(test.TestCase): """ example = _make_sequence_example() columns = [ - fc.categorical_column_with_identity('int_ctx', num_buckets=100), - fc.numeric_column('float_ctx'), + fc._categorical_column_with_identity('int_ctx', num_buckets=100), + fc._numeric_column('float_ctx'), col_fn(col_name, col_arg) ] context, seq_features = 
parsing_ops.parse_single_sequence_example( diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 2163af0b43864c96483df529f07881f2f985a80e..d5f74028298ee7015f5b2e3aaee7d9330c1acac1 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib from tensorflow.python.feature_column.feature_column import _LazyBuilder from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -109,13 +110,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=embedding_dimension_a, + embedding_column_a = fc._embedding_column( + categorical_column_a, + dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc.embedding_column( - categorical_column_b, dimension=embedding_dimension_b, + embedding_column_b = fc._embedding_column( + categorical_column_b, + dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -148,10 +151,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), 
dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) + embedding_column_a = fc._embedding_column(categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, @@ -206,7 +208,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) @@ -244,11 +246,11 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc.categorical_column_with_identity( + categorical_column_b = fc._categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( @@ -315,10 +317,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', 
num_buckets=vocabulary_size_b) - indicator_column_b = fc.indicator_column(categorical_column_b) + indicator_column_b = fc._indicator_column(categorical_column_b) input_layer, sequence_length = sfc.sequence_input_layer( features={ 'aaa': sparse_input_a, @@ -342,9 +344,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -530,7 +532,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) input_layer, _ = sfc.sequence_input_layer( features={'aaa': sparse_input}, feature_columns=[indicator_column]) @@ -616,8 +618,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) + embedding_column_a = fc._embedding_column(categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, @@ -639,7 +640,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -918,8 +919,9 @@ class SequenceEmbeddingColumnTest( categorical_column = 
sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=embedding_dimension, + embedding_column = fc._embedding_column( + categorical_column, + dimension=embedding_dimension, initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( @@ -956,8 +958,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=2) + embedding_column = fc._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -984,8 +985,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=2) + embedding_column = fc._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) @@ -1055,7 +1055,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -1101,7 +1101,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( 
[categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1152,7 +1152,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1218,7 +1218,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1250,7 +1250,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1277,7 +1277,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py 
b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 67ffb939663358b5e356b3b626978db959c1bac9..0d34ad161855476b6a4cd9a258521dbe122b4140 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -26,7 +26,7 @@ import collections from tensorflow.python.feature_column import feature_column as fc_old -from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -226,10 +226,8 @@ def sequence_categorical_column_with_identity( ValueError: if `default_value` is not in range `[0, num_buckets)`. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_identity( - key=key, - num_buckets=num_buckets, - default_value=default_value)) + fc_old._categorical_column_with_identity( + key=key, num_buckets=num_buckets, default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -269,10 +267,8 @@ def sequence_categorical_column_with_hash_bucket( ValueError: `dtype` is neither string nor integer. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_hash_bucket( - key=key, - hash_bucket_size=hash_bucket_size, - dtype=dtype)) + fc_old._categorical_column_with_hash_bucket( + key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -328,7 +324,7 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `dtype` is neither string nor integer. 
""" return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_vocabulary_file( + fc_old._categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -388,7 +384,7 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: if `dtype` is not integer or string. """ return fc_old._SequenceCategoricalColumn( - fc_old.categorical_column_with_vocabulary_list( + fc_old._categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, @@ -441,7 +437,7 @@ def sequence_numeric_column( ValueError: if any dimension in shape is not a positive integer. ValueError: if `dtype` is not convertible to `tf.float32`. """ - shape = fc._check_shape(shape=shape, key=key) + shape = fc_old._check_shape(shape=shape, key=key) if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py index 5ecd85807c55e592f2216dbe2ff76f56e5c2650d..ca4398a142065de0be7bee57cd7e54670bbae12e 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -25,7 +25,7 @@ import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc_old from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc from tensorflow.python.feature_column import feature_column as fc_old -from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column import _LazyBuilder from 
tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -111,13 +111,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old.embedding_column( - categorical_column_a, dimension=embedding_dimension_a, + embedding_column_a = fc_old._embedding_column( + categorical_column_a, + dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc_old.embedding_column( - categorical_column_b, dimension=embedding_dimension_b, + embedding_column_b = fc_old._embedding_column( + categorical_column_b, + dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -150,9 +152,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc_old._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old.embedding_column( + embedding_column_a = fc_old._embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( @@ -208,7 +210,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. 
- shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) @@ -246,11 +248,11 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc_old._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc_old.categorical_column_with_identity( + categorical_column_b = fc_old._categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( @@ -317,10 +319,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc_old.indicator_column(categorical_column_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size_b) - indicator_column_b = fc_old.indicator_column(categorical_column_b) + indicator_column_b = fc_old._indicator_column(categorical_column_b) input_layer, sequence_length = sfc.sequence_input_layer( features={ 'aaa': sparse_input_a, @@ -344,9 +346,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc_old.categorical_column_with_identity( + categorical_column_a = fc_old._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = 
fc_old.indicator_column(categorical_column_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -532,7 +534,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc_old.indicator_column(categorical_column) + indicator_column = fc_old._indicator_column(categorical_column) input_layer, _ = sfc.sequence_input_layer( features={'aaa': sparse_input}, feature_columns=[indicator_column]) @@ -618,7 +620,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc_old.embedding_column( + embedding_column_a = fc_old._embedding_column( categorical_column_a, dimension=2) with self.assertRaisesRegexp( @@ -641,7 +643,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc_old.indicator_column(categorical_column_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -920,8 +922,9 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( - categorical_column, dimension=embedding_dimension, + embedding_column = fc_old._embedding_column( + categorical_column, + dimension=embedding_dimension, initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( @@ -958,8 +961,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( - 
categorical_column, dimension=2) + embedding_column = fc_old._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -986,8 +988,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc_old.embedding_column( - categorical_column, dimension=2) + embedding_column = fc_old._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) @@ -1057,7 +1058,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -1103,7 +1104,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1154,7 +1155,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc_old.shared_embedding_columns( + shared_embedding_columns = fc.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = 
shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1220,7 +1221,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old.indicator_column(categorical_column) + indicator_column = fc_old._indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1252,7 +1253,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc_old.indicator_column(categorical_column) + indicator_column = fc_old._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index cd747df4d69d2c264f5a64b491da9570b1423770..dad50a3a73085526f65bd87c3d8549ceb75b3af4 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -47,6 +47,11 @@ tf_custom_op_py_library( ":variable_ops_op_lib", ], srcs_version = "PY2AND3", + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//video/youtube/personalization:__subpackages__", + ], deps = [ ":gen_variable_ops", "//tensorflow/contrib/util:util_py", @@ -66,6 +71,7 @@ tf_custom_op_py_library( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", "//tensorflow/python:smart_cond", + "//tensorflow/python:sort_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -311,17 +317,3 @@ py_test( "//third_party/py/numpy", ], ) - -py_test( - name = "sort_ops_test", - size = "medium", - srcs = ["python/ops/sort_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ - 
":framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 1921a77c1e96ee3531d1ed0f98e41c27c9d427ac..42184a4e55e292f7921702e3f8909ae54f717702 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -22,173 +22,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +from tensorflow.python.ops import sort_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops as framework_ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops - - -def sort(values, axis=-1, direction='ASCENDING', name=None): - """Sorts a tensor. - - Args: - values: 1-D or higher numeric `Tensor`. - axis: The axis along which to sort. The default is -1, which sorts the last - axis. - direction: The direction in which to sort the values (`'ASCENDING'` or - `'DESCENDING'`). - name: Optional name for the operation. - - Returns: - A `Tensor` with the same dtype and shape as `values`, with the elements - sorted along the given `axis`. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - with framework_ops.name_scope(name, 'sort'): - return _sort_or_argsort(values, axis, direction, return_argsort=False) - - -def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): - """Returns the indices of a tensor that give its sorted order along an axis. - - For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to - `tf.sort(values)`. 
For higher dimensions, the output has the same shape as - `values`, but along the given axis, values represent the index of the sorted - element in that slice of the tensor at the given position. - - Args: - values: 1-D or higher numeric `Tensor`. - axis: The axis along which to sort. The default is -1, which sorts the last - axis. - direction: The direction in which to sort the values (`'ASCENDING'` or - `'DESCENDING'`). - stable: If True, equal elements in the original tensor will not be - re-ordered in the returned order. Unstable sort is not yet implemented, - but will eventually be the default for performance reasons. If you - require a stable order, pass `stable=True` for forwards compatibility. - name: Optional name for the operation. - - Returns: - An int32 `Tensor` with the same shape as `values`. The indices that would - sort each slice of the given `values` along the given `axis`. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - del stable # Unused. - with framework_ops.name_scope(name, 'argsort'): - return _sort_or_argsort(values, axis, direction, return_argsort=True) - - -def _sort_or_argsort(values, axis, direction, return_argsort): - """Internal sort/argsort implementation. - - Args: - values: The input values. - axis: The axis along which to sort. - direction: 'ASCENDING' or 'DESCENDING'. - return_argsort: Whether to return the argsort result. - - Returns: - Either the sorted values, or the indices of the sorted values in the - original tensor. See the `sort` and `argsort` docstrings. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. 
- axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error - - values = framework_ops.convert_to_tensor(values, name='values') - - return _SORT_IMPL[direction](values, axis_static, return_argsort) - - -def _descending_sort(values, axis, return_argsort=False): - """Sorts values in reverse using `top_k`. - - Args: - values: Tensor of numeric values. - axis: Index of the axis which values should be sorted along. - return_argsort: If False, return the sorted values. If True, return the - indices that would sort the values. - - Returns: - The sorted values. - """ - k = array_ops.shape(values)[axis] - rank = array_ops.rank(values) - static_rank = values.shape.ndims - # Fast path: sorting the last axis. - if axis == -1 or axis + 1 == values.get_shape().ndims: - top_k_input = values - transposition = None - else: - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Calculate the actual axis index if counting from the end. Use the static - # rank if available, or else make the axis back into a tensor. - axis += static_rank or rank - if static_rank is not None: - # Prefer to calculate the transposition array in NumPy and make it a - # constant. - transposition = constant_op.constant( - np.r_[ - # Axes up to axis are unchanged. - np.arange(axis), - # Swap axis and rank - 1. - [static_rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - np.arange(axis + 1, static_rank - 1), - # Swap axis and rank - 1. - [axis]], - name='transposition') - else: - # Generate the transposition array from the tensors. - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. 
- math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - - values, indices = nn_ops.top_k(top_k_input, k) - return_value = indices if return_argsort else values - if transposition is not None: - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return_value = array_ops.transpose(return_value, transposition) - return return_value - - -def _ascending_sort(values, axis, return_argsort=False): - # Negate the values to get the ascending order from descending sort. - values_or_indices = _descending_sort(-values, axis, return_argsort) - # If not argsort, negate the values again. - return values_or_indices if return_argsort else -values_or_indices - - -_SORT_IMPL = { - 'ASCENDING': _ascending_sort, - 'DESCENDING': _descending_sort, -} +sort = sort_ops.sort +argsort = sort_ops.argsort diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 219cc199d79eca8c263859ae46bbb1ce0b4442b3..3593b501bb738b8f58dce4e40cffbdf410f136b3 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -113,7 +113,8 @@ class GANEstimator(estimator.Estimator): add_summaries=None, use_loss_summaries=True, config=None, - warm_start_from=None): + warm_start_from=None, + is_chief=True): """Initializes a GANEstimator instance. Args: @@ -154,6 +155,8 @@ class GANEstimator(estimator.Estimator): config: `RunConfig` object to configure the runtime settings. warm_start_from: A filepath to a checkpoint or saved model, or a WarmStartSettings object to configure initialization. + is_chief: Whether or not this Estimator is running on a chief or worker. + Needs to be set appropriately if using SyncReplicasOptimizers. 
Raises: ValueError: If loss functions aren't callable. @@ -187,7 +190,7 @@ class GANEstimator(estimator.Estimator): return _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn, use_loss_summaries) + get_hooks_fn, use_loss_summaries, is_chief) super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, @@ -215,7 +218,7 @@ def _get_gan_model( def _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn=None, use_loss_summaries=True): + get_hooks_fn=None, use_loss_summaries=True, is_chief=True): """Get the EstimatorSpec for the current mode.""" if mode == model_fn_lib.ModeKeys.PREDICT: estimator_spec = model_fn_lib.EstimatorSpec( @@ -236,7 +239,7 @@ def _get_estimator_spec( else discriminator_optimizer) get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, gopt, dopt, get_hooks_fn) + gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief) return estimator_spec @@ -321,11 +324,11 @@ def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None, def _get_train_estimator_spec( gan_model, gan_loss, generator_optimizer, discriminator_optimizer, - get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops): + get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops, is_chief=True): """Return an EstimatorSpec for the train case.""" scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, - discriminator_optimizer) + discriminator_optimizer, is_chief=is_chief) training_hooks = get_hooks_fn(train_ops) return model_fn_lib.EstimatorSpec( loss=scalar_loss, diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py 
b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 3d6bdab0ad7b4778edf0776f2d1b6a6f105cf2c7..bc9021050bc010ce75c3091fef868549686c0e90 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -48,6 +48,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import input as input_lib from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -82,7 +83,7 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): self.assertEqual(generator_inputs, gan_model.generator_inputs) self.assertIsNotNone(gan_model.generated_data) - self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer + self.assertLen(gan_model.generator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.generator_fn) if mode == model_fn_lib.ModeKeys.PREDICT: self.assertIsNone(gan_model.real_data) @@ -95,7 +96,7 @@ class GetGANModelTest(test.TestCase, parameterized.TestCase): self.assertIsNotNone(gan_model.real_data) self.assertIsNotNone(gan_model.discriminator_real_outputs) self.assertIsNotNone(gan_model.discriminator_gen_outputs) - self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer self.assertIsNotNone(gan_model.discriminator_scope) self.assertIsNotNone(gan_model.discriminator_fn) @@ -121,6 +122,7 @@ def get_dummy_gan_model(): def dummy_loss_fn(gan_model, add_summaries=True): + del add_summaries return math_ops.reduce_sum(gan_model.discriminator_real_outputs - gan_model.discriminator_gen_outputs) @@ -168,6 +170,35 @@ class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): self.assertShapeEqual(np.array(0), 
spec.loss) # must be a scalar self.assertIsNotNone(spec.eval_metric_ops) + def test_get_sync_estimator_spec(self): + """Make sure spec is loaded with sync hooks for sync opts.""" + + def get_sync_optimizer(): + return sync_replicas_optimizer.SyncReplicasOptimizer( + training.GradientDescentOptimizer(learning_rate=1.0), + replicas_to_aggregate=1) + + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + g_opt = get_sync_optimizer() + d_opt = get_sync_optimizer() + + spec = estimator._get_estimator_spec( + model_fn_lib.ModeKeys.TRAIN, + self._gan_model, + generator_loss_fn=dummy_loss_fn, + discriminator_loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=g_opt, + discriminator_optimizer=d_opt) + + self.assertLen(spec.training_hooks, 4) + sync_opts = [ + hook._sync_optimizer for hook in spec.training_hooks if + isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] + self.assertLen(sync_opts, 2) + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) + # TODO(joelshor): Add pandas test. 
class GANEstimatorIntegrationTest(test.TestCase): diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index e2594faf85bcf91cbe09f266e4d4211d20bdee17..364fa4eb461c62784803f0c309e3b7c5855df199 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -64,6 +64,9 @@ def condition_tensor(tensor, conditioning): """ tensor.shape[1:].assert_is_fully_defined() num_features = tensor.shape[1:].num_elements() + if conditioning.shape.ndims < 2: + raise ValueError('conditioning must be at least 2D, but saw shape: %s' + % conditioning.shape) mapped_conditioning = layers.linear( layers.flatten(conditioning), num_features) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py index 0aad769793761be69ee9d1e3416e44c7b3d8cea0..f5c7d53cf2c9aa08ba0074950983ef3ecd90168b 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py @@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase): array_ops.placeholder(dtypes.float32, (5, None)), array_ops.placeholder(dtypes.float32, (5, 1))) - with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'): + with self.assertRaisesRegexp(ValueError, 'at least 2D'): conditioning_utils.condition_tensor( array_ops.placeholder(dtypes.float32, (5, 2)), array_ops.placeholder(dtypes.float32, (5))) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index df0342c80c587cd0dfbf5f1455e05c31745995f5..a0a86c6337eefa756a209635faa70db686a36247 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ 
b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -36,7 +36,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib from tensorflow.python.framework import ops @@ -47,7 +46,6 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.distributions import distribution as ds from tensorflow.python.ops.losses import losses from tensorflow.python.ops.losses import util from tensorflow.python.summary import summary @@ -740,11 +738,16 @@ def least_squares_discriminator_loss( def _validate_distributions(distributions): if not isinstance(distributions, (list, tuple)): raise ValueError('`distributions` must be a list or tuple. Instead, ' - 'found %s.', type(distributions)) + 'found %s.' % type(distributions)) for x in distributions: - if not isinstance(x, ds.Distribution): + # We used to check with `isinstance(x, tf.distributions.Distribution)`. + # However, distributions have migrated to `tfp.distributions.Distribution`, + # which is a new code repo, so we can't check this way anymore until + # TF-GAN is migrated to a new repo as well. + # This new check is not sufficient, but is a useful heuristic for now. + if not callable(getattr(x, 'log_prob', None)): raise ValueError('`distributions` must be a list of `Distributions`. ' - 'Instead, found %s.', type(x)) + 'Instead, found %s.' % type(x)) def _validate_information_penalty_inputs( @@ -817,7 +820,7 @@ def _numerically_stable_global_norm(tensor_list): Returns: A scalar tensor with the global norm. 
""" - if np.all([x is None for x in tensor_list]): + if all(x is None for x in tensor_list): return 0.0 list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index b9ac1bf15138c7e7d15ab3ebdac605d84921b6e5..969b68449d9c82f9f9144a8657cd8932b38fd0f7 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -213,7 +213,8 @@ class GANTrainOps( collections.namedtuple('GANTrainOps', ( 'generator_train_op', 'discriminator_train_op', - 'global_step_inc_op' + 'global_step_inc_op', + 'train_hooks' ))): """GANTrainOps contains the training ops. @@ -221,8 +222,17 @@ class GANTrainOps( generator_train_op: Op that performs a generator update step. discriminator_train_op: Op that performs a discriminator update step. global_step_inc_op: Op that increments the shared global step. + train_hooks: a list or tuple containing hooks related to training that need + to be populated when training ops are instantiated. Used primarily for + sync hooks. 
""" + def __new__(cls, generator_train_op, discriminator_train_op, + global_step_inc_op, train_hooks=()): + return super(GANTrainOps, cls).__new__(cls, generator_train_op, + discriminator_train_op, + global_step_inc_op, train_hooks) + class GANTrainSteps( collections.namedtuple('GANTrainSteps', ( diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 7ee39f304ab213a8fa4e7a6f03cda88037bff9a1..4c7bee41b33ce1fee46d374ca5fd1c0b603762f9 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -114,7 +114,7 @@ def gan_model( discriminator_gen_outputs = discriminator_fn(generated_data, generator_inputs) with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) + real_data = _convert_tensor_or_l_or_d(real_data) discriminator_real_outputs = discriminator_fn(real_data, generator_inputs) if check_shapes: @@ -924,6 +924,7 @@ def gan_train_ops( generator_optimizer, discriminator_optimizer, check_for_unused_update_ops=True, + is_chief=True, # Optional args to pass directly to the `create_train_op`. **kwargs): """Returns GAN train ops. @@ -939,6 +940,8 @@ def gan_train_ops( discriminator_optimizer: The optimizer for the discriminator updates. check_for_unused_update_ops: If `True`, throws an exception if there are update ops outside of the generator or discriminator scopes. + is_chief: Specifies whether or not the training is being run by the primary + replica during replica training. **kwargs: Keyword args to pass directly to `training.create_train_op` for both the generator and discriminator train op. @@ -980,6 +983,9 @@ def gan_train_ops( kwargs, model.generator_scope.name, model.discriminator_scope.name, check_for_unused_update_ops) + # Get the sync hooks if these are needed. 
+ sync_hooks = [] + generator_global_step = None if isinstance(generator_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): @@ -995,6 +1001,7 @@ def gan_train_ops( trainable=False, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) gen_update_ops += [generator_global_step.assign(global_step)] + sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief)) with ops.name_scope('generator_train'): gen_train_op = training.create_train_op( total_loss=loss.generator_loss, @@ -1016,6 +1023,7 @@ def gan_train_ops( trainable=False, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) dis_update_ops += [discriminator_global_step.assign(global_step)] + sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief)) with ops.name_scope('discriminator_train'): disc_train_op = training.create_train_op( total_loss=loss.discriminator_loss, @@ -1025,7 +1033,8 @@ def gan_train_ops( update_ops=dis_update_ops, **kwargs) - return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc) + return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc, + sync_hooks) # TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive @@ -1066,13 +1075,24 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): train_steps.generator_train_steps) discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op, train_steps.discriminator_train_steps) - return [generator_hook, discriminator_hook] + return [generator_hook, discriminator_hook] + list(train_ops.train_hooks) return get_hooks +def _num_joint_steps(train_steps): + g_steps = train_steps.generator_train_steps + d_steps = train_steps.discriminator_train_steps + # Get the number of each type of step that should be run. 
+ num_d_and_g_steps = min(g_steps, d_steps) + num_g_steps = g_steps - num_d_and_g_steps + num_d_steps = d_steps - num_d_and_g_steps + + return num_d_and_g_steps, num_g_steps, num_d_steps + + def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): - """Returns a hooks function for sequential GAN training. + """Returns a hooks function for joint GAN training. When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON ALL OPTIMIZERS TO AVOID RACE CONDITIONS. @@ -1105,12 +1125,7 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. """ - g_steps = train_steps.generator_train_steps - d_steps = train_steps.discriminator_train_steps - # Get the number of each type of step that should be run. - num_d_and_g_steps = min(g_steps, d_steps) - num_g_steps = g_steps - num_d_and_g_steps - num_d_steps = d_steps - num_d_and_g_steps + num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps) def get_hooks(train_ops): g_op = train_ops.generator_train_op @@ -1120,7 +1135,7 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): g_hook = RunTrainOpsHook(g_op, num_g_steps) d_hook = RunTrainOpsHook(d_op, num_d_steps) - return [joint_hook, g_hook, d_hook] + return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks) return get_hooks diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 64d670619905a427a84bee4b661228abca591fae..841f25cd7f1852767776eed2dcbf2522d8b0743b 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -519,7 +519,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): """Test output type.""" loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.GANLoss) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + 
self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) @parameterized.named_parameters( ('cyclegan', create_cyclegan_model), @@ -528,7 +528,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): def test_cyclegan_output_type(self, get_gan_model_fn): loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.CycleGANLoss) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) @parameterized.named_parameters( ('gan', create_gan_model, False), @@ -759,7 +759,7 @@ class TensorPoolAdjusteModelTest(test.TestCase): # For [pool_size, ?), the pool is full, tensor2 must be equal to some # historical values of tensor1 (which is previously stored in the # pool). - self.assertTrue(any([(v == t2).all() for v in history_values])) + self.assertTrue(any((v == t2).all() for v in history_values)) def _make_new_model_and_check(self, model, pool_size): pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size) @@ -836,6 +836,9 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(train_ops, namedtuples.GANTrainOps) + # Make sure there are no training hooks populated accidentally. + self.assertEmpty(train_ops.train_hooks) + # TODO(joelshor): Add a test to check that custom update op is run. @parameterized.named_parameters( ('gan', create_gan_model, False), @@ -923,8 +926,15 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt) self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # No new trainable variables should have been added. - self.assertEqual(num_trainable_vars, - len(variables_lib.get_trainable_variables())) + self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars) + + # Sync hooks should be populated in the GANTrainOps. 
+ self.assertLen(train_ops.train_hooks, 2) + for hook in train_ops.train_hooks: + self.assertIsInstance( + hook, sync_replicas_optimizer._SyncReplicasOptimizerHook) + sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks] + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1) d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1) @@ -959,6 +969,32 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): coord.request_stop() coord.join(g_threads + d_threads) + @parameterized.named_parameters( + ('is_chief', True), + ('is_not_chief', False), + ) + def test_is_chief_in_train_hooks(self, is_chief): + """Make sure is_chief is propagated correctly to sync hooks.""" + model = create_gan_model() + loss = train.gan_loss(model) + g_opt = get_sync_optimizer() + d_opt = get_sync_optimizer() + train_ops = train.gan_train_ops( + model, + loss, + g_opt, + d_opt, + is_chief=is_chief, + summarize_gradients=True, + colocate_gradients_with_ops=True) + + self.assertLen(train_ops.train_hooks, 2) + for hook in train_ops.train_hooks: + self.assertIsInstance( + hook, sync_replicas_optimizer._SyncReplicasOptimizerHook) + is_chief_list = [hook._is_chief for hook in train_ops.train_hooks] + self.assertListEqual(is_chief_list, [is_chief, is_chief]) + class GANTrainTest(test.TestCase, parameterized.TestCase): """Tests for `gan_train`.""" @@ -1036,6 +1072,44 @@ class GANTrainTest(test.TestCase, parameterized.TestCase): self.assertTrue(np.isscalar(final_loss)) self.assertEqual(17.0, final_loss) + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn): + model = create_gan_model_fn() + loss = 
train.gan_loss(model) + + g_opt = get_sync_optimizer() + d_opt = get_sync_optimizer() + train_ops = train.gan_train_ops( + model, + loss, + g_opt, + d_opt, + summarize_gradients=True, + colocate_gradients_with_ops=True) + + sequential_train_hooks = train.get_sequential_train_hooks()(train_ops) + self.assertLen(sequential_train_hooks, 4) + sync_opts = [ + hook._sync_optimizer for hook in sequential_train_hooks if + isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] + self.assertLen(sync_opts, 2) + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) + + joint_train_hooks = train.get_joint_train_hooks()(train_ops) + self.assertLen(joint_train_hooks, 5) + sync_opts = [ + hook._sync_optimizer for hook in joint_train_hooks if + isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] + self.assertLen(sync_opts, 2) + self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) + class PatchGANTest(test.TestCase, parameterized.TestCase): """Tests that functions work on PatchGAN style output.""" diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 94f522c04e5a09ed2d9355fa675125c340407923..fbccbead03fc0d641db40ede661bf3677d44c45d 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -170,6 +170,14 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // Record "call" in active_ so that it can be aborted cleanly. RegisterCall(call); + // RendezvousMgr already aborted, shouldn't send RPC call any more + if (!call->status().ok()) { + done(call->status(), Args(), Args(), Tensor(), false); + session()->worker_cache->ReleaseWorker(src_worker, rwi); + delete call; + return; + } + // Start "call". 
Ref(); call->Start([this, call, src_worker, rwi, done]() { diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py index f7f1189bb93c611719186a697c40f371644f63a2..bc941ae9f23eaa5c46fcca95b9aba0ac0d87960a 100644 --- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -47,7 +48,7 @@ class SequenceFileDatasetTest(test.TestCase): dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat( num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py index bf398b838dfaaff6fdaf33a6cd7086ef13e43a3e..5c5599858ee6879a5703d65658bf4bbd881c7e72 100644 --- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -20,10 +20,9 @@ from __future__ import print_function from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape class 
SequenceFileDataset(dataset_ops.DatasetSource): @@ -40,15 +39,12 @@ class SequenceFileDataset(dataset_ops.DatasetSource): For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() # Prints the (key, value) pairs inside a hadoop sequence file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: @@ -60,16 +56,10 @@ class SequenceFileDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_dataset_ops.sequence_file_dataset( - self._filenames, nest.flatten(self.output_types)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + self._filenames, self._element_structure._flat_types) # pylint: disable=protected-access @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index c7db0b77e25668fb8a42d204776044420f403e44..5a8c650fb927be0c835aaceffc516c048195c7bf 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -54,14 +54,12 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> for _ in range(3): ->>> 
print(sess.run(next_obj)) +>>> for element in dataset: +>>> print(element) {'key': 1, 'val': {'NAME': b'WARM KITTY'}} {'key': 2, 'val': {'NAME': b'SOFT KITTY'}} @@ -74,23 +72,22 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> tf.enable_eager_execution() +>>> >>> dataset = IgniteDataset(cache_name="IMAGES") ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> for element in dataset.take(1): +>>> print(element) { - 'key': 'kitten.png', + 'key': 'kitten.png', 'val': { 'metadata': { 'file_name': b'kitten.png', 'label': b'little ball of fur', - width: 800, + width: 800, height: 600 - }, + }, 'pixels': [0, 0, 0, 0, ..., 0] } } @@ -100,13 +97,11 @@ jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> >>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels']) ->>> iterator = dataset.make_one_shot_iterator() ->>> next_obj = iterator.get_next() >>> ->>> with tf.Session() as sess: ->>> print(sess.run(next_obj)) +>>> for element in dataset: +>>> print(element) [0, 0, 0, 0, ..., 0] ``` @@ -126,18 +121,18 @@ Ignite Dataset allows using these two aspects of distributed neural network trai ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> +>>> >>> dataset = IgniteDataset("IMAGES") >>> >>> # Compute gradients locally on every worker node. 
->>> gradients = [] +>>> gradients = [] >>> for i in range(5): >>> with tf.device("/job:WORKER/task:%d" % i): ->>> device_iterator = dataset.make_one_shot_iterator() +>>> device_iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) >>> device_next_obj = device_iterator.get_next() >>> gradient = compute_gradient(device_next_obj) ->>> gradients.append(gradient) ->>> +>>> gradients.append(gradient) +>>> >>> # Aggregate them on master node. >>> result_gradient = tf.reduce_sum(gradients) >>> @@ -145,7 +140,7 @@ Ignite Dataset allows using these two aspects of distributed neural network trai >>> print(sess.run(result_gradient)) ``` -High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. +High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. ### Distributed File System diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py index 936b29a4f50794380d48efed99e267c6b4c44dc6..e4762c91b193f9c5e32fa2642e702e61e8e5e57f 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py @@ -27,6 +27,7 @@ import six from tensorflow.contrib.ignite.python.ops import gen_dataset_ops from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -34,10 +35,7 @@ from tensorflow.python.framework import tensor_shape @six.add_metaclass(abc.ABCMeta) class Readable(object): - """Readable abstract class that exposes methods to do reading-related - - 
operations. - """ + """Abstract class that exposes methods to do reading-related operations.""" @abc.abstractmethod def __init__(self): @@ -227,10 +225,7 @@ types = { class TypeTreeNode(object): - """TypeTreeNode class exposes methods to format object tree structure - - data. - """ + """TypeTreeNode class exposes methods to format object tree structure data.""" def __init__(self, name, type_id, fields=None, permutation=None): """Constructs a new instance of TypeTreeNode. @@ -692,14 +687,14 @@ class IgniteClient(TcpClient): class IgniteDataset(dataset_ops.DatasetSource): - """Apache Ignite is a memory-centric distributed database, caching, and - - processing platform for transactional, analytical, and streaming workloads, - delivering in-memory speeds at petabyte scale. This contrib package - contains an integration between Apache Ignite and TensorFlow. The - integration is based on tf.data from TensorFlow side and Binary Client - Protocol from Apache Ignite side. It allows to use Apache Ignite as a - datasource for neural network training, inference and all other + """Apache Ignite is a memory-centric distributed database. + + It acts as a caching and processing platform for transactional, analytical, + and streaming workloads, delivering in-memory speeds at petabyte scale. + This contrib package contains an integration between Apache Ignite and + TensorFlow. The integration is based on tf.data from TensorFlow side and + Binary Client Protocol from Apache Ignite side. It allows to use Apache + Ignite as a datasource for neural network training, inference and all other computations supported by TensorFlow. Ignite Dataset is based on Apache Ignite Binary Client Protocol. 
""" @@ -756,6 +751,9 @@ class IgniteDataset(dataset_ops.DatasetSource): self.cache_type.to_permutation(), dtype=dtypes.int32, name="permutation") + self._structure = structure.convert_legacy_structure( + self.cache_type.to_output_types(), self.cache_type.to_output_shapes(), + self.cache_type.to_output_classes()) def _as_variant_tensor(self): return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port, @@ -763,13 +761,5 @@ class IgniteDataset(dataset_ops.DatasetSource): self.schema, self.permutation) @property - def output_classes(self): - return self.cache_type.to_output_classes() - - @property - def output_shapes(self): - return self.cache_type.to_output_shapes() - - @property - def output_types(self): - return self.cache_type.to_output_types() + def _element_structure(self): + return self._structure diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py index ef29b5f14a4b2fea2400ec4d56a7ad2cf44cf2cb..ff5d4c458c859fd8e5e3ae65ee41a454d55d6538 100644 --- a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py +++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.ignite import IgniteDataset from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -65,7 +66,7 @@ class IgniteDatasetTest(test.TestCase): self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"]) self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"]) - it = dataset.make_one_shot_iterator() + it = dataset_ops.make_one_shot_iterator(dataset) ne = it.get_next() with session.Session() as sess: diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index 
478b716d88321101c971789f36c0ff8ecd3f418e..108da04494685f06f9afc26a26a5dadcdd99b0ff 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -115,7 +115,7 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, &tranformation_matrix]( + [&input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { // Applying projection matrix to input RGB vectors. const float* p = input_data.data() + start_channel * kChannelSize; diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py index 24b790977dfdb675ff7bf0a119a08e243a30d3aa..ae9c7a611945e1445c933d74b9944054b3f0e0a4 100644 --- a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py @@ -24,7 +24,7 @@ from tensorflow.contrib.image.python.ops import dense_image_warp from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes - +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients @@ -259,7 +259,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase): shape = [1, 2, 1, 1] msg = 'Should have raised an exception for invalid image size' - with self.assertRaises(ValueError, msg=msg): + with self.assertRaises(errors.InvalidArgumentError, msg=msg): self.check_interpolation_correctness(shape, 'float32', 'float32') diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 
4997c31a7fc7f4243d03b22fc9c01fb13a2a25a4..ba5cdfebf92c07e496ed588848d5859ff6a5bff2 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -281,6 +281,13 @@ class ImageOpsTest(test_util.TensorFlowTestCase): value.eval(), np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype())) + @test_util.run_in_graph_and_eager_modes + def test_transform_eager(self): + image = constant_op.constant([[1., 2.], [3., 4.]]) + value = image_ops.transform(image, [1] * 8) + with self.test_session(use_gpu=True): + self.assertAllEqual(self.evaluate(value), np.array([[4, 4], [4, 4]])) + class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py index 9c7ada7afb7cb620c2e06887795053778f133287..f7ced440720209cb05dfcd79395c51517f9de0d5 100644 --- a/tensorflow/contrib/image/python/ops/dense_image_warp.py +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops @@ -60,28 +61,38 @@ def _interpolate_bilinear(grid, msg = 'Grid must be 4 dimensional. Received size: ' raise ValueError(msg + str(grid.get_shape())) - batch_size, height, width, channels = shape + batch_size, height, width, channels = (array_ops.shape(grid)[0], + array_ops.shape(grid)[1], + array_ops.shape(grid)[2], + array_ops.shape(grid)[3]) + + shape = [batch_size, height, width, channels] query_type = query_points.dtype grid_type = grid.dtype - if (query_points.shape.rank != 3 or - query_points.shape.dims[2].value != 2): - msg = ('Query points must be 3 dimensional and size 2 in dim 2. 
Received ' - 'size: ') - raise ValueError(msg + str(query_points.get_shape())) - - _, num_queries, _ = query_points.get_shape().as_list() - - if height < 2 or width < 2: - msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' - raise ValueError(msg + str(grid.get_shape())) - - alphas = [] - floors = [] - ceils = [] - - index_order = [0, 1] if indexing == 'ij' else [1, 0] - unstacked_query_points = array_ops.unstack(query_points, axis=2) + with ops.control_dependencies([ + check_ops.assert_equal( + len(query_points.get_shape()), + 3, + message='Query points must be 3 dimensional.'), + check_ops.assert_equal( + array_ops.shape(query_points)[2], + 2, + message='Query points must be size 2 in dim 2.') + ]): + num_queries = array_ops.shape(query_points)[1] + + with ops.control_dependencies([ + check_ops.assert_greater_equal( + height, 2, message='Grid height must be at least 2.'), + check_ops.assert_greater_equal( + width, 2, message='Grid width must be at least 2.') + ]): + alphas = [] + floors = [] + ceils = [] + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) for dim in index_order: with ops.name_scope('dim-' + str(dim)): @@ -112,16 +123,18 @@ def _interpolate_bilinear(grid, alpha = array_ops.expand_dims(alpha, 2) alphas.append(alpha) - if batch_size * height * width > np.iinfo(np.int32).max / 8: - error_msg = """The image size or batch size is sufficiently large - that the linearized addresses used by array_ops.gather - may exceed the int32 limit.""" - raise ValueError(error_msg) - - flattened_grid = array_ops.reshape(grid, - [batch_size * height * width, channels]) - batch_offsets = array_ops.reshape( - math_ops.range(batch_size) * height * width, [batch_size, 1]) + with ops.control_dependencies([ + check_ops.assert_less_equal( + math_ops.cast(batch_size * height * width, dtype=dtypes.float32), + np.iinfo(np.int32).max / 8, + message="""The image size or batch size is 
sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""") + ]): + flattened_grid = array_ops.reshape( + grid, [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) # This wraps array_ops.gather. We reshape the image data such that the # batch, y, and x coordinates are pulled into the first dimension. @@ -182,7 +195,11 @@ def dense_image_warp(image, flow, name='dense_image_warp'): of dimensions. """ with ops.name_scope(name): - batch_size, height, width, channels = image.get_shape().as_list() + batch_size, height, width, channels = (array_ops.shape(image)[0], + array_ops.shape(image)[1], + array_ops.shape(image)[2], + array_ops.shape(image)[3]) + # The flow is defined on the image grid. Turn the flow into a list of query # points in the grid space. grid_x, grid_y = array_ops.meshgrid( diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index d4fb99a017faebe30384d739f22f4ff5fa986bc4..b25a6f7b5742917a032946fe03a0dab20e7dc1ad 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.contrib.image.ops import gen_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -271,8 +272,11 @@ def transform(images, raise TypeError("Images should have rank between 2 and 4.") if output_shape is None: - output_shape = tensor_util.constant_value( - array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3] + output_shape = array_ops.shape(images)[1:3] + if not context.executing_eagerly(): + output_shape_value = tensor_util.constant_value(output_shape) + if output_shape_value is not None: 
+ output_shape = output_shape_value output_shape = ops.convert_to_tensor( output_shape, dtypes.int32, name="output_shape") diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 7129f09e8b42e48a9c768fd4a66cde3d4da9d31d..2b86331099ccae03664462987ee0c141d766c10f 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -20,9 +20,9 @@ from __future__ import print_function from tensorflow.contrib.kafka.python.ops import gen_dataset_ops from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape class KafkaDataset(dataset_ops.DatasetSource): @@ -63,13 +63,5 @@ class KafkaDataset(dataset_ops.DatasetSource): self._group, self._eof, self._timeout) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index 3327a9f9a613bfb56e6a25af0fe1c0ca18609035..9e19884df852c0fd259a55aef56c62b4189cd1da 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,7 +20,7 @@ from __future__ import print_function # Generic layers. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index 47cd01b924fb43e8a83836c58f8ced61e9e88268..3b9fa1b230b837a350d521c4165053c187786201 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -30,6 +30,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions from tensorflow.python.keras.utils.np_utils import normalize from tensorflow.python.keras.utils.np_utils import to_categorical from tensorflow.python.keras.utils.vis_utils import plot_model diff --git a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py index de7530231db4ea4f50996a67eb8c0d6936db9dd3..1626e55b9b3bc82bd96703bfab765ac6ad81f462 100644 --- a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py +++ b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py @@ -90,7 +90,7 @@ def _update_features_and_columns(features, feature_columns, mapped_column_name = column_name + "_MAPPED" # Construct new feature columns based on provided kernel_mappers. 
column_kernel_mappers = kernel_mappers_dict[feature_column] - new_dim = sum([mapper.output_dim for mapper in column_kernel_mappers]) + new_dim = sum(mapper.output_dim for mapper in column_kernel_mappers) mapped_columns.add( layers.feature_column.real_valued_column(mapped_column_name, new_dim)) diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py index 75806dbbeb1819bb0a6965bbc384e02df9895210..20395395281768ac429984a1e3552cfd187527a2 100644 --- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -20,9 +20,9 @@ from __future__ import print_function from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape class KinesisDataset(dataset_ops.DatasetSource): @@ -34,15 +34,12 @@ class KinesisDataset(dataset_ops.DatasetSource): For example, we can construct and use the KinesisDataset as follows: ```python + tf.enable_eager_execution() + dataset = tf.contrib.kinesis.KinesisDataset( "kinesis_stream_name", read_indefinitely=False) - next = dataset.make_one_shot_iterator().get_next() - with tf.Session() as sess: - while True: - try: - print(sess.run(nxt)) - except tf.errors.OutOfRangeError: - break + for element in dataset: + print(element) ``` Since Kinesis is a data streaming service, data may not be available @@ -84,13 +81,5 @@ class KinesisDataset(dataset_ops.DatasetSource): self._stream, self._shard, self._read_indefinitely, self._interval) @property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return 
tensor_shape.scalar() - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index e6596bfdfb9b153e5946ab7f8933c160cf2f2326..9ca6f8df5dbe3c236c4cd85095176ce69ad9deaa 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -78,6 +78,11 @@ tf_custom_op_py_library( ":sparse_feature_cross_op_op_lib", ], srcs_version = "PY2AND3", + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//video/youtube/personalization:__subpackages__", + ], deps = [ ":sparse_feature_cross_op", "//tensorflow/contrib/framework:framework_py", @@ -253,7 +258,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -277,7 +282,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index 124515e5a6474f2cc1038830346e27277c6ceea7..295c721fceda6aaaf8672525ceed560308db6af7 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import itertools import math +import sys import numpy as np @@ -36,6 +37,7 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables +from 
tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -48,11 +50,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): assert num_shards > 0 assert num_shards <= vocab_size - embedding_weights = partitioned_variables.create_partitioned_variables( + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32) + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[vocab_size, embed_dim], - slicing=[num_shards, 1], - initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)) + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=initializer)) for w in embedding_weights: w.initializer.run() embedding_weights = [w.eval() for w in embedding_weights] @@ -256,6 +260,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights, sparse_ids, sparse_weights) +# pylint: disable=invalid-name +def local_variable_scope(): + """Create a variable scope named like the caller function.""" + return variable_scope.variable_scope(sys._getframe(1).f_code.co_name) +# pylint: enable=invalid-name + + class ScatteredEmbeddingLookupTest(test.TestCase): def setUp(self): @@ -266,17 +277,18 @@ class ScatteredEmbeddingLookupTest(test.TestCase): assert num_shards > 0 assert num_shards <= size - embedding_weights = partitioned_variables.create_partitioned_variables( + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[size], - slicing=[num_shards], + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0, dtype=dtypes.float32)) + mean=0.0, stddev=1.0, dtype=dtypes.float32))) for w in embedding_weights: w.initializer.run() return embedding_weights def test_scattered_embedding_consistency(self): - with 
self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) @@ -288,7 +300,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1]) def test_scattered_embedding_multiple_partition(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights(num_shards=7) values = constant_op.constant([4, 4, 5]) @@ -304,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertGreater(embedding_diff, 0) def test_scattered_embedding_coverage(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): size = 8 embedding_weights = self._random_weights(size=size, num_shards=3) values = constant_op.constant(["foo"]) @@ -316,7 +328,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertEqual(len(np.unique(embedding_lookup_result[0])), size) def test_scattered_embedding_multi_dimension(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -329,7 +341,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][2]) def test_scattered_embedding_lookup_sparse(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights(num_shards=3) sparse_tensor = sparse_tensor_lib.SparseTensor( values=["foo", "bar", "foo", "bar"], @@ -358,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embeds = np.random.randn(n_embed, d_embed) idx = np.random.randint(0, n_embed, idx_shape) - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedded_np = embeds[idx] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -370,7 +382,7 @@ class 
ScatteredEmbeddingLookupTest(test.TestCase): idx = np.random.randint(0, 5, 10) idx2d = np.random.randint(0, 5, (10, 2)) - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedded_np = embeds[idx] embedded_np2d = embeds[idx2d] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -398,17 +410,18 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): assert num_shards > 0 assert num_shards <= size - embedding_weights = partitioned_variables.create_partitioned_variables( + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[size], - slicing=[num_shards], + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0, dtype=dtypes.float32)) + mean=0.0, stddev=1.0, dtype=dtypes.float32))) for w in embedding_weights: w.initializer.run() return embedding_weights def test_hashed_embedding_consistency(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) # The first three sampled_candidates are equal, so the first three @@ -429,7 +442,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][3]) def test_hashed_embedding_multi_dimension(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -691,7 +704,6 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): index += num_val return grouped_vals - @test_util.enable_c_shapes def testEmbeddingLookupSparse(self): vocab_size = 13 batch_size = 10 diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py index 
f42112206d0db9d2e42bd4cff19f6a6533951d46..3671633c8d795034b13cb55fd6db87c453e9fa12 100644 --- a/tensorflow/contrib/layers/python/layers/encoders.py +++ b/tensorflow/contrib/layers/python/layers/encoders.py @@ -84,8 +84,7 @@ def bow_encoder(ids, if isinstance(ids, sparse_tensor.SparseTensor): raise TypeError('ids are expected to be dense Tensor, got: %s', ids) return math_ops.reduce_mean( - embedding_ops.embedding_lookup(embeddings, ids), - reduction_indices=1) + embedding_ops.embedding_lookup(embeddings, ids), axis=1) def embed_sequence(ids, diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 222404b19db2b93b695ee6d2cb16584e17033700..00d819ed0e9fe3a5644105a571beda100204631e 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1015,8 +1015,7 @@ class _OneHotColumn( dense_id_tensor, depth=self.length, on_value=1.0, off_value=0.0) # Reduce to get a multi-hot per example. 
- return math_ops.reduce_sum( - one_hot_id_tensor, reduction_indices=[output_rank - 1]) + return math_ops.reduce_sum(one_hot_id_tensor, axis=[output_rank - 1]) @property def _variable_shape(self): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 6fb4b9ff3534cab34c84de5d13fea7aff756556d..7e6eafaa0d6f60cfc28a4c422abac0b6d5a991fb 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -27,7 +27,7 @@ from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index d90d6ecf7f671a40a7ff2b066b6782c7421f9887..cab8da808b6413518ff4864cb0b03a42809260f1 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -27,7 +27,7 @@ import numpy as np from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.layers.python.layers import feature_column_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as 
sparse_tensor_lib diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index ac9561c7693fc4ad994a00889aa3f15b4b5a5ee4..403b522ce45ac6ad98a321378626b87aaa7738aa 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base from tensorflow.python.layers import convolutional as convolutional_layers from tensorflow.python.layers import core as core_layers @@ -1958,7 +1959,7 @@ class GDN(base.Layer): self._reparam_offset = reparam_offset self.data_format = data_format self._channel_axis() # trigger ValueError early - self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5) + self.input_spec = input_spec.InputSpec(min_ndim=3, max_ndim=5) def _channel_axis(self): try: @@ -2015,7 +2016,7 @@ class GDN(base.Layer): raise ValueError('The channel dimension of the inputs to `GDN` ' 'must be defined.') self._input_rank = input_shape.ndims - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( ndim=input_shape.ndims, axes={ channel_axis: num_channels }) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 8ead6336a08db4dd52edf0d3372db5a50f860e2b..d791418c9d0f887058ceb535092fa8122da1aa75 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1459,13 +1459,6 @@ class DropoutTest(test.TestCase): class FlattenTest(test.TestCase): - def testInvalidRank(self): - with ops.Graph().as_default() as g, self.session(g): - inputs = array_ops.placeholder(dtype=dtypes.float32) - 
inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): - _layers.flatten(inputs) - def testUnknownLastDim(self): with ops.Graph().as_default() as g, self.session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -1502,6 +1495,12 @@ class FlattenTest(test.TestCase): images.get_shape().num_elements()) self.assertEqual(output.get_shape()[0], images.get_shape()[0]) + def testFlatten0D(self): + with self.cached_session(): + scalars = random_ops.random_uniform((5,), seed=1, name='scalars') + output = _layers.flatten(scalars) + self.assertEqual(output.shape, (5, 1)) + def testFlattenBatchSize(self): height, width = 3, 3 with self.cached_session() as sess: @@ -3811,7 +3810,7 @@ class UnitNormTests(test.TestCase): image = random_ops.random_uniform((height, width, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), axis=dim)) shape = [height, width, 3] del shape[dim] @@ -3847,7 +3846,7 @@ class UnitNormTests(test.TestCase): image = array_ops.placeholder(dtypes.float32, (None, None, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), axis=dim)) with self.cached_session(): actual = norms.eval({image: placeholder_value}) diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py index 51faba30c74d64c54d3d2b11d2a11195cca6b759..5cb00b76847430be8ade9f4e4fc8f7372035485a 100644 --- a/tensorflow/contrib/layers/python/layers/regularizers_test.py +++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py @@ -141,7 +141,7 @@ class RegularizerTest(test.TestCase): dummy_regularizer = lambda x: math_ops.reduce_sum(2 * x) 
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]] tensor_weights_list = [constant_op.constant(x) for x in array_weights_list] - expected = sum([2 * x for l in array_weights_list for x in l]) + expected = sum(2 * x for l in array_weights_list for x in l) with self.cached_session(): result = regularizers.apply_regularization(dummy_regularizer, tensor_weights_list) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 61185f65a9bd294003515456f891de0a68661a82..14065fcee51c014a1af227504eaaca1fa39941e1 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -24,6 +24,11 @@ py_library( exclude = ["python/learn/**/*_test.py"], ), srcs_version = "PY2AND3", + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//video/youtube/personalization:__subpackages__", + ], # This library should not depend on sklearn, even though some of the code # refers to it. (The code handles the presence of sklearn conditionally.) 
deps = [ @@ -269,6 +274,7 @@ py_test( name = "estimator_test", size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], + shard_count = 2, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index eabebb7e881558471c343c0573cc9a8f4a425312..10fbd60ba2df4c3f84169bf04f249d67dc14573f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -28,7 +28,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -38,11 +37,12 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.training import training_util # The default learning rate of 0.05 is a historical artifact of the initial # implementation, but seems a reasonable choice. 
@@ -150,10 +150,10 @@ def _dnn_model_fn(features, labels, mode, params, config=None): "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=input_layer_partitioner) as input_layer_scope: - if all([ + if all( isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access for fc in feature_columns - ]): + ): net = layers.input_from_feature_columns( columns_to_tensors=features, feature_columns=feature_columns, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 3d85533d92d17095bae9a69f229171e1bf61ba10..2ade6b7b6ce2678ec8df7c98ffaa5636ae9d4b1d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -38,7 +38,7 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn @@ -236,10 +236,10 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=input_layer_partitioner) as dnn_input_scope: - if all([ + if all( isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access for fc in dnn_feature_columns - ]): + ): net = layers.input_from_feature_columns( columns_to_tensors=features, feature_columns=dnn_feature_columns, @@ -292,8 +292,8 @@ def _dnn_linear_combined_model_fn(features, 
labels, mode, params, config=None): linear_parent_scope, values=tuple(six.itervalues(features)), partitioner=linear_partitioner) as scope: - if all([isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access - for fc in linear_feature_columns]): + if all(isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access + for fc in linear_feature_columns): if joint_linear_weights: linear_logits, _, _ = layers.joint_weighted_sum_from_feature_columns( columns_to_tensors=features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 4e65c180d8bee9ab8fe9b1fbf32edc229c31af09..d46a873bfaa297e7f6242aa56e9d0bf0eb551867 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -36,7 +36,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 2bd57597c2e9444b51b1dacfbe4180b443c95a3d..ee25cebd484f1e831fe8b6d3aa7290da7558adee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.contrib.learn.python.learn.estimators import 
run_config from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 1d8a59281a4934ad063362cba064e6cb3abff5a2..28c4964527bb034c8c6b1642366c6c82c1a72201 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -668,7 +668,7 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): sequences = centers + noise inputs = array_ops.expand_dims(sequences, 2) - labels = math_ops.reduce_mean(sequences, reduction_indices=[1]) + labels = math_ops.reduce_mean(sequences, axis=[1]) return {'inputs': inputs}, labels return input_fn @@ -722,8 +722,8 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2) labels = math_ops.to_int32( array_ops.squeeze( - math_ops.reduce_sum( - inputs, reduction_indices=[1]) > (sequence_length / 2.0))) + math_ops.reduce_sum(inputs, axis=[1]) > ( + sequence_length / 2.0))) return {'inputs': inputs}, labels return input_fn diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 8bc869db895b753be805219892342b5e6ea3799b..9132b2209bce8005b323d058d6d176784a84b2d1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ 
b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -1066,11 +1066,11 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): - saver_hook_exists = any([ + saver_hook_exists = any( isinstance(h, basic_session_run_hooks.CheckpointSaverHook) for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks + model_fn_ops.training_chief_hooks) - ]) + ) if not saver_hook_exists: chief_hooks = [ basic_session_run_hooks.CheckpointSaverHook( @@ -1493,7 +1493,7 @@ class Estimator(BaseEstimator): # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): """Scikit learn wrapper for TensorFlow Learn Estimator. - + THIS CLASS IS DEPRECATED. See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) for general migration instructions. diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index e100bc7a1e7be4896e9ab1c965775b5185b38897..9ee8d8004bf26224dd96a98bad109720c44d04f7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -37,7 +37,7 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.contrib.linear_optimizer.python import sdca_optimizer -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -155,8 +155,8 @@ def _linear_model_fn(features, labels, mode, params, config=None): parent_scope, 
values=tuple(six.itervalues(features)), partitioner=partitioner) as scope: - if all([isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access - for fc in feature_columns]): + if all(isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access + for fc in feature_columns): if joint_weights: layer_fn = layers.joint_weighted_sum_from_feature_columns else: diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 597ca4e86dbf66c86182f14a2a364b662d52fb0a..dfc76bfde6c0109f98093232b6f223d6938007f9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.linear_optimizer.python import sdca_optimizer as sdca_optimizer_lib from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor @@ -1745,7 +1745,7 @@ class LinearRegressorTest(test.TestCase): 'place_holder': constant_op.constant([[0.0]] * num_examples), }, constant_op.constant( - [[1 if i % 4 is 0 else 0] for i in range(num_examples)]) + [[1 if i % 4 == 0 else 0] for i in range(num_examples)]) place_holder = feature_column_lib.real_valued_column('place_holder') sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 
29552d24f1eaa0d85a99c8b09f69d007e7e4fe9f..59a67636ae275c5ca1df21685770baa7a960d667 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -27,7 +27,7 @@ from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_n from tensorflow.python.util.deprecation import deprecated -@deprecated(None, 'Use tf.estimator.inputs.numpy_input_fn.') +@deprecated(None, 'Use tf.compat.v1.estimator.inputs.numpy_input_fn.') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index b4ef055f5ae484ec704ad42efcf2c00c4a7a4f56..e9df7258a358d9543f2bb386518d900bd6ddef74 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -53,7 +53,7 @@ PANDAS_DTYPES = { } -@deprecated(None, 'Please use tf.estimator.inputs.pandas_input_fn') +@deprecated(None, 'Please use tf.compat.v1.estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 8466dc36d13e223aed4f1dfe8e39a6f91c99fa55..d49834dc860a8b4341ddd3720fde52281f7474f7 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for SdcaModel.""" +"""Tests for SdcaModel (deprecated). + +This module and all its submodules are deprecated. 
To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f3f1dcd98db5ae24af154d1f0851a0688d2bc611..c056a12fa5307a7e9ac4cf30e1386ddfd5cd7d75 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Proximal stochastic dual coordinate ascent optimizer for linear models.""" +# pylint: disable=line-too-long +"""Proximal stochastic dual coordinate ascent optimizer for linear models (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" +# pylint: enable=line-too-long from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -40,6 +47,7 @@ from tensorflow.python.ops import variables as var_ops from tensorflow.python.ops.nn import log_poisson_loss from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits from tensorflow.python.summary import summary +from tensorflow.python.util import deprecation __all__ = ['SdcaModel'] @@ -48,7 +56,7 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - Loss functions supported: + Loss functions supported: * Binary logistic loss * Squared loss @@ -109,6 +117,10 @@ class SdcaModel(object): ``` """ + @deprecation.deprecated( + None, 'This class is deprecated. 
To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, examples, variables, options): """Create a new sdca optimizer.""" diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a001555e8f257c88a52fdb40d4181f5cd9c92e84..a28394964a12013c43d85701b5a0ab5c559afd62 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sharded mutable dense hash table.""" +"""Sharded mutable dense hash table (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -28,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation # TODO(rohanj): This should subclass Checkpointable and implement @@ -45,6 +51,10 @@ class ShardedMutableDenseHashTable(object): # TODO(andreasst): consider moving this to lookup module + @deprecation.deprecated( + None, 'This class is deprecated. 
To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, key_dtype, value_dtype, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py index 2b56d0fa3a8b8564b7c73a62bd99cc900d6f5c54..2d1457f9e4cc576da696be191e718814dd9ff4e5 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sharded_mutable_dense_hashtable.py.""" +"""Tests for sharded_mutable_dense_hashtable.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py index 003795233ff2b28e33fc10388ef25efb63c43bb0..64730f8eed1ff9bfcd4a980dceb28abb98e39f73 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sparse feature column.""" +"""Sparse feature column (deprecated). + +This module and all its submodules are deprecated. 
To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope +from tensorflow.python.util import deprecation class SparseFeatureColumn(object): @@ -68,6 +74,10 @@ class SparseFeatureColumn(object): @@feature_values """ + @deprecation.deprecated( + None, 'This class is deprecated. To UPDATE or USE linear optimizers, ' + 'please check its latest version in core: ' + 'tensorflow_estimator/python/estimator/canned/linear_optimizer/.') def __init__(self, example_indices, feature_indices, feature_values): """Creates a `SparseFeatureColumn` representation. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py index 51c4f68543da2f563481cc2d35b556796616cf9d..0ae780e1a100c7dadde7196803f2ae0d4bcb2334 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for sparse_feature_column.py.""" +"""Tests for sparse_feature_column.py (deprecated). + +This module and all its submodules are deprecated. To UPDATE or USE linear +optimizers, please check its latest version in core: +tensorflow_estimator/python/estimator/canned/linear_optimizer/. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index 647667188238dc18b137eaad98356a79b3a549b4..7a5354222f103aa0f45adc513079e420bbbfd30c 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -524,7 +524,7 @@ class SDCALinearRegressorTest(test.TestCase): # LinearClassifier requires at least one column. 'place_holder': constant_op.constant([[0.0]] * num_examples), - }, constant_op.constant([[1 if i % 4 is 0 else 0] + }, constant_op.constant([[1 if i % 4 == 0 else 0] for i in range(num_examples)]) with self._single_threaded_test_session(): diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5e99ef460518fa761b12533e5dc07dc252f1d582..9b2c2dd87cc8a92fbb6b45504939be3788b60839 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -25,6 +25,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session from tensorflow.python.data.experimental.ops import counter +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -2737,7 +2738,7 @@ class MutableHashTableBenchmark(test.Benchmark): def benchmark_many_repeated_scalar_insert_scalar(self): table = self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() + c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() value = variables.Variable(1.0) insert = table.insert(c, value) size = table.size() @@ -2758,7 +2759,7 @@ class MutableHashTableBenchmark(test.Benchmark): def benchmark_many_repeated_batch_32_insert_scalar(self): table = 
self._create_table() - c = counter.Counter().make_one_shot_iterator().get_next() + c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() value = variables.Variable([1.0] * 32) insert = table.insert(32 * c + list(range(32)), value) size = table.size() diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 619294b51822bd9983eda777acae5cf0d253926d..709a042bbcefb89125f7e4cd14a0d7ecd2b53281 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -22,7 +22,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import add_arg_scope -from tensorflow.python.compat import compat from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -60,41 +59,12 @@ def _scale_losses(losses, weights): """ # First, compute the sum of the losses over all elements: start_index = max(0, weights.get_shape().ndims) - reduction_indices = list(range(start_index, losses.get_shape().ndims)) - reduced_losses = math_ops.reduce_sum( - losses, reduction_indices=reduction_indices) + axis = list(range(start_index, losses.get_shape().ndims)) + reduced_losses = math_ops.reduce_sum(losses, axis=axis) reduced_losses = math_ops.multiply(reduced_losses, weights) return math_ops.reduce_sum(reduced_losses) -def _safe_div(numerator, denominator, name="value"): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - name: An optional name for the returned op. 
- - Returns: - The element-wise value of the numerator divided by the denominator. - """ - if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator, name=name) - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator), - name=name) - - def _safe_mean(losses, num_present): """Computes a safe mean of the losses. @@ -107,7 +77,7 @@ def _safe_mean(losses, num_present): then zero is returned. """ total_loss = math_ops.reduce_sum(losses) - return _safe_div(total_loss, num_present, name="value") + return math_ops.div_no_nan(total_loss, num_present, name="value") @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.") @@ -187,10 +157,9 @@ def _num_present(losses, weights, per_batch=False): # First, count the number of nonzero weights: if weights.get_shape().ndims >= 1: - reduction_indices = list(range(1, weights.get_shape().ndims)) + axis = list(range(1, weights.get_shape().ndims)) num_nonzero_per_batch = math_ops.reduce_sum( - math_ops.to_float(math_ops.not_equal(weights, 0)), - reduction_indices=reduction_indices) + math_ops.to_float(math_ops.not_equal(weights, 0)), axis=axis) # Next, determine the number of elements that weights would broadcast to: broadcast_dims = array_ops.slice( @@ -606,20 +575,20 @@ def mean_pairwise_squared_error(predictions, if weights.get_shape().ndims is None: raise ValueError("weights.get_shape().ndims cannot be None") - reduction_indices = list(range(1, diffs.get_shape().ndims)) + axis = list(range(1, diffs.get_shape().ndims)) sum_squares_diff_per_batch = math_ops.reduce_sum( - math_ops.square(diffs), reduction_indices=reduction_indices) + math_ops.square(diffs), axis=axis) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch, - 
name="value") + term1 = 2.0 * math_ops.div_no_nan( + sum_squares_diff_per_batch, num_present_per_batch, name="value") - sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch), - name="value") + sum_diff = math_ops.reduce_sum(diffs, axis=axis) + term2 = 2.0 * math_ops.div_no_nan( + math_ops.square(sum_diff), + math_ops.square(num_present_per_batch), + name="value") loss = _scale_losses(term1 - term2, weights) @@ -674,7 +643,7 @@ def cosine_distance(predictions, radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ axis, ]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 0a07588f07f0bb89dbf5dc5909f511f74470fb41..2a5232b476712a96f84be0f4725beb78bc138297 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,11 +30,13 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow -# 1.10 branch does not work. `make distclean` fails and blocks the build -# process. For now we're hardcoding to the version which is used by -# TensorFlow 1.9. -PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" + +# Note: The protobuf repo needs to be cloned due to its submodules. 
+# These variables contain the GitHub repo to clone and the sha, taken from +# `tensorflow/workspace.bzl`, to check out after cloning. +readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" +readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" + # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" @@ -91,11 +93,34 @@ download_and_extract() { find "${dir}" -type f -name '*BUILD' -delete } +function clone_repository() { + local repo_url="${1}" + local destination_directory="${2}" + local commit_sha="${3}" + + if [[ -d "${destination_directory}" ]]; then + rm -rf "${destination_directory}" + fi + + git clone "${repo_url}" "${destination_directory}" + + pushd "$(pwd)" 1>/dev/null + + cd "${destination_directory}" + + if [[ -n "${commit_sha}" ]]; then + git checkout "${commit_sha}" + fi + + git submodule update --init + + popd 1>/dev/null +} + download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" -download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" download_and_extract "${DOUBLE_CONVERSION_URL}" "${DOWNLOADS_DIR}/double_conversion" @@ -106,6 +131,8 @@ download_and_extract "${CUB_URL}" "${DOWNLOADS_DIR}/cub/external/cub_archive" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}"
"${DOWNLOADS_DIR}/flatbuffers" +clone_repository "${PROTOBUF_REPO}" "${DOWNLOADS_DIR}/protobuf" "${PROTOBUF_TAG}" + replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index e779eff68901af7042deb5c09b78a230e0d06d02..655c7eefcb978d40c8bc16a23685e03ed71bfb63 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -157,6 +157,7 @@ tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc +tensorflow/core/kernels/multinomial_op.cc tensorflow/core/kernels/no_op.cc tensorflow/core/kernels/non_max_suppression_op.cc tensorflow/core/kernels/one_hot_op.cc @@ -252,6 +253,7 @@ tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc +tensorflow/core/kernels/stateless_random_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc tensorflow/core/kernels/strided_slice_op_inst_1.cc diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index ac1236086503a7c6e541bdf098efcb92f84e577f..9aabc4bec3053871e3ff6cd3a88fd76d293f48cc 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from 
__future__ import division from __future__ import print_function +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics_impl from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ @@ -175,7 +175,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200, return best_f1 best_f1 = distribution_strategy_context.get_replica_context().merge_call( - f1_across_replicas, values) + f1_across_replicas, args=(values,)) update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], fn=update_ops['fn'], name='update') diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py index d6a670f97b32a29129cb9ea0cd71c5a2b7597a47..e789d2cb9dfbac7b1e145be48b3f707af3fd4e18 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification_test.py +++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py @@ -291,12 +291,11 @@ class F1ScoreTest(test.TestCase): labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) - tf_predictions, tf_labels = (dataset_ops.Dataset - .from_tensor_slices((predictions, labels)) - .repeat() - .batch(batch_size) - .make_one_shot_iterator() - .get_next()) + tf_predictions, tf_labels = dataset_ops.make_one_shot_iterator( + dataset_ops.Dataset + .from_tensor_slices((predictions, labels)) + .repeat() + .batch(batch_size)).get_next() f1, f1_op = classification.f1_score(tf_labels, tf_predictions, num_thresholds=3) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 
d6932f6e4b603b1a76250ab622f5fe8eaea81bc9..7b432f8bd20989c6d95310bcaca88d44ce3e0d1f 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -24,7 +24,6 @@ from __future__ import print_function import collections as collections_lib -from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -46,32 +45,6 @@ from tensorflow.python.util.deprecation import deprecated _EPSILON = 1e-7 -def _safe_div(numerator, denominator): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - - Returns: - The element-wise value of the numerator divided by the denominator. - """ - if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator) - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator)) - - @deprecated(None, 'Please switch to tf.metrics.true_positives. 
Note that the ' 'order of the labels and predictions arguments has been switched.') def streaming_true_positives(predictions, @@ -3247,24 +3220,20 @@ def streaming_covariance(predictions, # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) # batch_mean_prediction is E[x_B] in the update equation - batch_mean_prediction = _safe_div( - math_ops.reduce_sum(weighted_predictions), - batch_count) - delta_mean_prediction = _safe_div( - (batch_mean_prediction - mean_prediction) * batch_count, - update_count) + batch_mean_prediction = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_predictions), batch_count) + delta_mean_prediction = math_ops.div_no_nan( + (batch_mean_prediction - mean_prediction) * batch_count, update_count) update_mean_prediction = state_ops.assign_add(mean_prediction, delta_mean_prediction) # prev_mean_prediction is E[x_A] in the update equation prev_mean_prediction = update_mean_prediction - delta_mean_prediction # batch_mean_label is E[y_B] in the update equation - batch_mean_label = _safe_div( - math_ops.reduce_sum(weighted_labels), - batch_count) - delta_mean_label = _safe_div( - (batch_mean_label - mean_label) * batch_count, - update_count) + batch_mean_label = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_labels), batch_count) + delta_mean_label = math_ops.div_no_nan( + (batch_mean_label - mean_label) * batch_count, update_count) update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label @@ -3447,7 +3416,7 @@ def streaming_mean_cosine_distance(predictions, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) radial_diffs = math_ops.reduce_sum( - radial_diffs, reduction_indices=[ + radial_diffs, axis=[ dim, ], keepdims=True) mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, @@ -3926,9 +3895,8 @@ def 
cohen_kappa(labels, po_sum = math_ops.reduce_sum(po) total = math_ops.reduce_sum(pe_row) pe_sum = math_ops.reduce_sum( - _safe_div( - math_ops.to_double(pe_row * pe_col), - math_ops.to_double(total))) + math_ops.div_no_nan( + math_ops.to_double(pe_row * pe_col), math_ops.to_double(total))) po_sum, pe_sum, total = (math_ops.to_double(po_sum), math_ops.to_double(pe_sum), math_ops.to_double(total)) diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 1b0383d24c0c472b4875d15c3650e37dfd2439e1..c922d0cd11fda3c51a51ceccf69798df7ce75f26 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test def _GetExampleIter(inputs): dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - return dataset.make_one_shot_iterator() + return dataset_ops.make_one_shot_iterator(dataset) class FixedLossScaleManagerTest(test.TestCase): diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index 9009df0eefec13146090ba5fc2096e71ba6eb89d..33f9a43e803ea845a25bba284e41e5a0e6228dad 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -132,7 +132,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 opt = gd.GradientDescentOptimizer(lr) @@ -182,7 +182,7 @@ class LossScaleOptimizerTest(test.TestCase): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) dataset = 
dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1]) - itr = dataset.make_one_shot_iterator() + itr = dataset_ops.make_one_shot_iterator(dataset) lr = 1 init_loss_scale = 8 diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py index f0ce6fe03966c2de2dfd8ebcca07bf46afcf4fce..1fa5c8cb485704a5fccc486e823bbc4050bf505a 100644 --- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -119,7 +120,7 @@ class _MaskedConv(base.Layer): self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer - self.input_spec = base.InputSpec(ndim=self.rank + 2) + self.input_spec = input_spec.InputSpec(ndim=self.rank + 2) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -171,7 +172,7 @@ class _MaskedConv(base.Layer): dtype=self.dtype) else: self.bias = None - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( ndim=self.rank + 2, axes={channel_axis: input_dim}) self.built = True @@ -393,14 +394,14 @@ class MaskedFullyConnected(base.Layer): self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer - self.input_spec = base.InputSpec(min_ndim=2) + self.input_spec = input_spec.InputSpec(min_ndim=2) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if tensor_shape.dimension_value(input_shape[-1]) is None: raise ValueError('The last dimension of the inputs to `Dense` ' 
'should be defined. Found `None`.') - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( min_ndim=2, axes={-1: tensor_shape.dimension_value(input_shape[-1])}) self.kernel = self.add_variable( diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 6c203e5519e6a66d20e2509eca3c74eb66bf32c7..fa1a7aaff0aa59a6a64b1f0bf836a273926d785d 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import optimizer from tensorflow.python.training import saver from tensorflow.python.training import session_run_hook +from tensorflow.python.training.saving import saveable_object_util LOCAL_VARIABLE_NAME = 'local_center_variable' GLOBAL_VARIABLE_NAME = 'global_center_variable' @@ -424,7 +425,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.trainable_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) swapped_var_list = {} for key, var in var_list.items(): @@ -464,4 +465,4 @@ class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook): def after_create_session(self, session, coord): """Run initialization ops""" - session.run(self._variable_init_op) \ No newline at end of file + session.run(self._variable_init_op) diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py index a8dafd9a4cb9c669400f74b545b3c165bd49b2a2..bc18177b6d0b1d3f4fc58236bbc3d445fb73d80d 100644 --- a/tensorflow/contrib/opt/python/training/lars_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py @@ -18,6 +18,7 @@ from __future__ import 
absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -162,3 +163,14 @@ class LARSOptimizer(optimizer.Optimizer): math_ops.cast(self._momentum_tensor, grad.dtype), use_locking=self._use_locking, use_nesterov=self._use_nesterov) + + def _prepare(self): + learning_rate = self._learning_rate + if callable(learning_rate): + learning_rate = learning_rate() + self._learning_rate_tensor = ops.convert_to_tensor( + learning_rate, name="learning_rate") + momentum = self._momentum + if callable(momentum): + momentum = momentum() + self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum") diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index b7fd2d2fb9db3eed15eb1cc2934199939790b1c0..bf3e5c51f78cc3ca3c7c77009c9cf428c4988953 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer from tensorflow.python.training import saver +from tensorflow.python.training.saving import saveable_object_util class MovingAverageOptimizer(optimizer.Optimizer): @@ -165,7 +166,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): - var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + var_list = saveable_object_util.op_list_to_dict(var_list) v_name_to_tensor = {} for k, tensor_or_list in six.iteritems(var_list): diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py 
b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index 155ff5b3f4f29d4d9c81bb265d19d1b8cce4fef2..960826407b66b4efa3c2693efb6d2e17c4b47b33 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -83,14 +84,14 @@ class NadamOptimizer(adam.AdamOptimizer): with ops.control_dependencies([m_t]): m_t = scatter_add(m, indices, m_scaled_g_values) # m_bar = (1 - beta1) * g_t + beta1 * m_t - m_bar = m_scaled_g_values + beta1_t * m_t + m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices) # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) v = self.get_slot(var, "v") v_scaled_g_values = (grad * grad) * (1 - beta2_t) v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) with ops.control_dependencies([v_t]): v_t = scatter_add(v, indices, v_scaled_g_values) - v_sqrt = math_ops.sqrt(v_t) - var_update = state_ops.assign_sub( - var, lr * m_bar / (v_sqrt + epsilon_t), use_locking=self._use_locking) + v_t_slice = array_ops.gather(v_t, indices) + v_sqrt = math_ops.sqrt(v_t_slice) + var_update = scatter_add(var, indices, -lr * m_bar / (v_sqrt + epsilon_t)) return control_flow_ops.group(*[var_update, m_bar, v_t]) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index 85e05ce71cec6ef897cadb7d123e630febb3c064..a4372f64874e7591dbceac901fad6c941209bef9 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -52,14 +52,19 @@ def nadam_update_numpy(param, class 
NadamOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): + # need to use a larger value of epsilon here so that + # np.sqrt(v_t) + epsilon doesn't get rounded to 0 when + # the dtype is half and np.sqrt(v_t) = 0, as is the case + # when the gradient is 0 + sparse_epsilon = 1e-7 for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0, 0.01], dtype=dtype.as_numpy_dtype) if use_resource: var0 = resource_variable_ops.ResourceVariable(var0_np) @@ -67,21 +72,21 @@ class NadamOptimizerTest(test.TestCase): else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) - grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0_np_indices = np.array([0, 2], dtype=np.int32) grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([2])) - grads1_np_indices = np.array([0, 1], dtype=np.int32) + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + grads1_np_indices = np.array([0, 2], dtype=np.int32) grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([2])) - opt = nadam_optimizer.NadamOptimizer() + constant_op.constant(grads1_np[grads1_np_indices]), + constant_op.constant(grads1_np_indices), constant_op.constant([3])) + opt = 
nadam_optimizer.NadamOptimizer(epsilon=sparse_epsilon) update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 3.0, 4.0], var1.eval()) beta1_power, beta2_power = opt._get_beta_accumulators() @@ -91,8 +96,10 @@ class NadamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) update.run() - var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0, + epsilon=sparse_epsilon) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1, + epsilon=sparse_epsilon) # Validate updated params self.assertAllCloseAccordingToType(var0_np, var0.eval()) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 200b0d200826a6212a236680327f4daf7d07831f..8b8065c678e11e8fc237e71cf1d392ced5c22ada 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -59,6 +59,23 @@ class DecoupledWeightDecayExtension(object): Note that this extension decays weights BEFORE applying the update based on the gradient, i.e. this extension only has the desired behaviour for optimizers which do not depend on the value of'var' in the update step! + + Note: when applying a decay to the learning rate, be sure to manually apply + the decay to the `weight_decay` as well. 
For example: + + ```python + schedule = tf.train.piecewise_constant(tf.train.get_global_step(), + [10000, 15000], [1e-0, 1e-1, 1e-2]) + lr = 1e-1 * schedule() + wd = lambda: 1e-4 * schedule() + + # ... + + optimizer = tf.contrib.opt.MomentumWOptimizer(learning_rate=lr, + weight_decay=wd, + momentum=0.9, + use_nesterov=True) + ``` """ def __init__(self, weight_decay, **kwargs): diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 3ba3ee29ec79687df522eb330665a2ce80061682..6e401406308604970677003aeea0f15c64cc74b6 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -48,7 +48,6 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", @@ -56,6 +55,8 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", ], ) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 467dd86d8fd247a42be2dc47d5bf9872e14da89e..7fb23abc38d9dc101204ed83808aebe5a8ef1e78 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -24,6 +24,9 @@ import abc import six +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx +from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -34,8 +37,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops 
import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.training import distribution_strategy_context as distribute_ctx from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -446,7 +447,7 @@ class _OptimizerV2State(object): if v is None: if colocate_with is None: colocate_with = self._non_slot_devices - with self._distribution.colocate_vars_with(colocate_with): + with self._distribution.extended.colocate_vars_with(colocate_with): # TODO(josh11b): Use get_variable() except for the legacy Adam use case. v = variable_scope.variable(initial_value, name=name, trainable=False) self._non_slot_dict[name] = v @@ -657,7 +658,6 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list=None, gate_gradients=GATE_OP, aggregation_method=None, - colocate_gradients_with_ops=False, name=None, grad_loss=None, stop_gradients=None, @@ -680,8 +680,6 @@ class OptimizerV2(optimizer_v1.Optimizer): `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. - colocate_gradients_with_ops: If True, try colocating gradients with the - corresponding op. name: Optional name for the returned operation. grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. stop_gradients: Optional. A Tensor or list of tensors not to differentiate @@ -704,8 +702,8 @@ class OptimizerV2(optimizer_v1.Optimizer): Minimization (and gradient computation) is done with respect to the elements of `var_list` if not None, else with respect to any trainable variables created during the execution of the `loss` function. - `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and - `grad_loss` are ignored when eager execution is enabled. 
+ `gate_gradients`, `aggregation_method`, and `grad_loss` are ignored when + eager execution is enabled. @end_compatibility """ grads_and_vars = self.compute_gradients( @@ -713,7 +711,6 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list=var_list, gate_gradients=gate_gradients, aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, grad_loss=grad_loss, stop_gradients=stop_gradients, scale_loss_by_num_replicas=scale_loss_by_num_replicas) @@ -733,7 +730,6 @@ class OptimizerV2(optimizer_v1.Optimizer): var_list=None, gate_gradients=GATE_OP, aggregation_method=None, - colocate_gradients_with_ops=False, grad_loss=None, stop_gradients=None, scale_loss_by_num_replicas=None): @@ -756,8 +752,6 @@ class OptimizerV2(optimizer_v1.Optimizer): `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. - colocate_gradients_with_ops: If True, try colocating gradients with the - corresponding op. grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. stop_gradients: Optional. A Tensor or list of tensors not to differentiate through. @@ -776,8 +770,8 @@ class OptimizerV2(optimizer_v1.Optimizer): not callable. @compatibility(eager) - When eager execution is enabled, `gate_gradients`, `aggregation_method`, - and `colocate_gradients_with_ops` are ignored. + When eager execution is enabled, `gate_gradients`, and `aggregation_method` + are ignored. @end_compatibility """ # TODO(josh11b): Test that we handle weight decay in a reasonable way. 
@@ -832,7 +826,6 @@ class OptimizerV2(optimizer_v1.Optimizer): grad_ys=grad_loss, gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP), aggregation_method=aggregation_method, - colocate_gradients_with_ops=colocate_gradients_with_ops, stop_gradients=stop_gradients) if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH: grads = control_flow_ops.tuple(grads) @@ -848,8 +841,7 @@ class OptimizerV2(optimizer_v1.Optimizer): """Scale loss for the number of replicas.""" if scale_loss_by_num_replicas is None: scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) + distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: num_replicas = \ distribute_ctx.get_distribution_strategy().num_replicas_in_sync @@ -892,7 +884,8 @@ class OptimizerV2(optimizer_v1.Optimizer): raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) return distribute_ctx.get_replica_context().merge_call( - self._distributed_apply, filtered, global_step=global_step, name=name) + self._distributed_apply, args=(filtered,), + kwargs={"global_step": global_step, "name": name}) def _get_or_create_state(self, var_list=None): """Either looks up or creates `_OptimizerV2State`. 
@@ -927,8 +920,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" - reduced_grads = distribution.batch_reduce( - variable_scope.VariableAggregation.SUM, grads_and_vars) + reduced_grads = distribution.extended.batch_reduce_to( + ds_reduce_util.ReduceOp.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) @@ -944,7 +937,7 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.name_scope(name, self._name) as name: per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list) # Include the current value of any dynamic hyper parameters in `state`. - non_slot_devices = distribution.non_slot_devices(var_list) + non_slot_devices = distribution.extended.non_slot_devices(var_list) state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access self._hyper, distribution, non_slot_devices) @@ -989,7 +982,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # Use the processors to update the variables. update_ops = [] for grad, var in grads_and_vars: - update_ops.extend(distribution.update(var, update, grad, grouped=False)) + update_ops.extend(distribution.extended.update( + var, update, args=(grad,), group=False)) # Give the child class a chance to do something after applying # gradients @@ -1001,12 +995,12 @@ class OptimizerV2(optimizer_v1.Optimizer): update_ops = control_flow_ops.group(update_ops) with ops.control_dependencies([update_ops]): - finish_updates = distribution.update_non_slot( - non_slot_devices, finish, grouped=False) - # We said grouped=False, which means finish_updates is always a list. - # It will be [None] when finish() returns None. 
- if finish_updates == [None]: - finish_updates = [update_ops] + finish_updates = distribution.extended.update_non_slot( + non_slot_devices, finish, group=False) + # We said group=False, which means finish_updates is always a tuple. + # It will be (None,) when finish() returns None. + if finish_updates == (None,): + finish_updates = (update_ops,) # Update `global_step` (if any). if global_step is None: @@ -1017,8 +1011,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def update_global_step(global_step, name): return global_step.assign_add(1, read_value=False, name=name) - apply_updates = distribution.update(global_step, update_global_step, - name) + apply_updates = distribution.extended.update( + global_step, update_global_step, args=(name,)) # Add the training op to the TRAIN_OP graph collection in graph mode. if not eager_execution: diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index d50b52b8ff1ce8188ab52c6968d716378efd9daa..53a3bc63e1d770b451846c45370fdee9ffa72d70 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -42,6 +42,7 @@ py_library( name = "saved_model_predictor", srcs = ["saved_model_predictor.py"], srcs_version = "PY2AND3", + visibility = ["//learning/brain/contrib/learn/tpu:__subpackages__"], deps = [ ":base_predictor", "//tensorflow/contrib/saved_model:saved_model_py", diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index a1f2b5902663e96bca8e13998869f4a0e9ae584b..9085d9fa719520ac84ef6f8e07d7fa335bef5605 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -28,7 +28,7 @@ Since it's difficult to add these fake quantization operations to all the required locations in the model, there's a function available that rewrites the training graph. To create a fake quantized training graph: -``` +```python # Build forward pass of model. 
loss = tf.losses.get_total_loss() @@ -51,7 +51,7 @@ The rewritten *eval graph* is non-trivially different from the *training graph* since the quantization ops affect the batch normalization step. Because of this, we've added a separate rewrite for the *eval graph*: -``` +```python # Build eval model logits = tf.nn.softmax_cross_entropy_with_logits_v2(...) diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 6f659347fba019288361dd0420f2ade6dc1bebaf..8619708cdaecd78bcc7de0e8e0cbf2baa11bf6a2 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -138,7 +138,7 @@ def LastValueQuantize(inputs, if per_channel: if input_dim >= 2: batch_min = math_ops.reduce_min( - inputs, reduction_indices=reduce_dims, name='BatchMin') + inputs, axis=reduce_dims, name='BatchMin') else: batch_min = inputs else: @@ -147,7 +147,7 @@ def LastValueQuantize(inputs, if per_channel: if input_dim >= 2: batch_max = math_ops.reduce_max( - inputs, reduction_indices=reduce_dims, name='BatchMax') + inputs, axis=reduce_dims, name='BatchMax') else: batch_max = inputs else: @@ -263,7 +263,7 @@ def MovingAvgQuantize(inputs, if per_channel: if input_dim >= 2: batch_min = math_ops.reduce_min( - inputs, reduction_indices=reduce_dims, name='BatchMin') + inputs, axis=reduce_dims, name='BatchMin') else: batch_min = inputs else: @@ -272,7 +272,7 @@ def MovingAvgQuantize(inputs, if per_channel: if input_dim >= 2: batch_max = math_ops.reduce_max( - inputs, reduction_indices=reduce_dims, name='BatchMax') + inputs, axis=reduce_dims, name='BatchMax') else: batch_max = inputs else: diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 338923f75125ed3d1a2b1046a65d563bc8f7d3e3..21d1b1213090273b5abd8e012f8711db98c94347 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ 
-160,7 +160,7 @@ def Quantize(graph, # shouldn't quantize it, since the activation will be Fused into the # Add at inference time. consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op) - if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]): + if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers): logging.info('Skipping %s, because its followed by an activation.', layer_match.bypass_op.name) else: @@ -195,7 +195,7 @@ def Quantize(graph, # Add at inference time. consumers = input_to_ops_map.ConsumerOperations( layer_match.post_activation_bypass_op) - if any([consumer.type in _RELU_TYPES for consumer in consumers]): + if any(consumer.type in _RELU_TYPES for consumer in consumers): logging.info('Skipping %s, because its followed by an activation.', layer_match.post_activation_bypass_op.name) else: diff --git a/tensorflow/contrib/rate/BUILD b/tensorflow/contrib/rate/BUILD index c461a7145e27c4238161cec989448be807acd543..76db9aecf615d0a94f65cd7ea799db245828db1c 100644 --- a/tensorflow/contrib/rate/BUILD +++ b/tensorflow/contrib/rate/BUILD @@ -34,6 +34,11 @@ py_test( name = "rate_test", size = "small", srcs = ["rate_test.py"], + tags = [ + "manual", # TODO(b/120555555) + "no_oss", # TODO(b/120555555) + "notap", # TODO(b/120555555) + ], deps = [ ":rate", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD index 38fcca03116721f3dabfa6d1e7122c369b6b405d..bbf109967595a73a0fc4bacaf34859b30c2376fc 100644 --- a/tensorflow/contrib/resampler/BUILD +++ b/tensorflow/contrib/resampler/BUILD @@ -13,6 +13,7 @@ load( ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") tf_custom_op_py_library( name = "resampler_py", @@ -50,10 +51,14 @@ tf_kernel_library( prefix = "resampler_ops", deps = [ ":resampler_ops_op_lib", - 
"//tensorflow/compiler/tf2xla/kernels:resampler_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla/kernels:resampler_ops", + ], + "//conditions:default": [], + }), alwayslink = 1, ) @@ -94,3 +99,26 @@ cuda_py_test( "//tensorflow/python:array_ops", ], ) + +tf_xla_py_test( + name = "resampler_ops_xla_test", + size = "small", + srcs = ["xla/resampler_ops_xla_test.py"], + disabled_backends = [ + # TODO(b/74459949) Support BatchDot in CPU backend. + "cpu", + "cpu_ondemand", + ], + # TODO(b/112295522): the OSS build will not likely work in the short to medium term, currently it is blocked by the fact that bazel does not allow py_library to depend on cc_library: https://github.com/bazelbuild/bazel/issues/701 which may not be resolvable. + tags = ["no_oss"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/kernels:resampler_ops", + "//tensorflow/contrib/resampler:resampler_ops", + "//tensorflow/contrib/resampler:resampler_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py similarity index 76% rename from tensorflow/compiler/tests/resampler_ops_test.py rename to tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py index f87ac3360c905d7956ab3716c47d42765949774d..d8ca0eab276b39f025d018edebb78eed7a8433bb 100644 --- a/tensorflow/compiler/tests/resampler_ops_test.py +++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py @@ -63,8 +63,8 @@ class ResamplerOpsTest(xla_test.XLATestCase): def testSimple(self): for dtype in self.float_types: input_shape = [1, 2, 2, 1] - input_rgb_data = [0, 5, 13, 54] - input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + input_data = [0, 5, 13, 54] + input_np = 
np.array(input_data, dtype=dtype).reshape(input_shape) warp_shape = [1, 2] warp_data = [0.7, 0.6] @@ -151,6 +151,55 @@ class ResamplerOpsTest(xla_test.XLATestCase): expected_grad_data, expected_grad_warp) + def testOutOfBoundWarps(self): + # (x, y) are both less than 0. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-1, -1, 0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [27.62]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # One of (x, y) is less than 0. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-1, 0.1, 0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [27.62]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # Both of (x, y) are greater than image size. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-0.1, 0.1, 1.2, 2.1] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [0.0]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # One of (x, y) is greater than image size. 
+ for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [0.1, -0.1, 1.2, 0.1] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [0.0]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 245fa68eaef43ca8bc18c6087460d946228b0c85..7d57b0413a3bb51c35e670ce3fdb2cc818f44a58 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -906,7 +906,7 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperKeepNoOutput(self): keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-10) + keep_none = variable_scope.get_variable("none", initializer=1e-6) res = self._testDropoutWrapper( input_keep_prob=keep_all, output_keep_prob=keep_none, @@ -922,7 +922,7 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self): keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-10) + keep_none = variable_scope.get_variable("none", initializer=1e-6) # Even though we dropout state, by default DropoutWrapper never # drops out the memory ("c") term of an LSTMStateTuple. 
res = self._testDropoutWrapper( @@ -943,7 +943,7 @@ class DropoutWrapperTest(test.TestCase): def testDropoutWrapperKeepNoInput(self): keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-10) + keep_none = variable_scope.get_variable("none", initializer=1e-6) true_full_output = np.array( [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], dtype=np.float32) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 5cba54dd3df5bbb33380505bd5a073f069a3a590..ef372b947cedf71e9d44423f10cc43375b467cd9 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -227,7 +227,7 @@ class RNNTest(test.TestCase): def testDropout(self): cell = Plus1RNNCell() full_dropout_cell = rnn_cell.DropoutWrapper( - cell, input_keep_prob=1e-12, seed=0) + cell, input_keep_prob=1e-6, seed=0) (name, dep), = full_dropout_cell._checkpoint_dependencies self.assertIs(dep, cell) self.assertEqual("cell", name) diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index b30ca7882fce1747cb1dcb27f97f5b012ff9da02..251a933eaec826b08266123245d9aef8573d3e06 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ -21,7 +21,7 @@ from tensorflow.contrib.rnn.ops import gen_gru_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -165,7 +165,7 @@ class GRUBlockCell(LayerRNNCell): num_units = cell_size self._cell_size = 
num_units # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 4db431f85a467389717e98d87875afce5e08b974..b043026bc556a8879b15b432829baf8136250c0e 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -25,6 +25,7 @@ from tensorflow.contrib.rnn.ops import gen_lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -385,7 +386,7 @@ class LSTMBlockCell(LayerRNNCell): "scope": "lstm_cell" } # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): @@ -628,7 +629,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): self._use_peephole = use_peephole # Inputs must be 3-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=3) + self.input_spec = input_spec.InputSpec(ndim=3) @property def num_units(self): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index e159dc95796e8f02287a4b6db4d25023348fe8da..8a1c09f171e6108174671e3122d5ff4c0b236003 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import activations from tensorflow.python.keras import initializers -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_array_ops @@ -2752,7 +2752,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): self._activation = activation or math_ops.tanh # Restrict inputs to be 2-dimensional matrices - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): @@ -3089,7 +3089,7 @@ class IndRNNCell(rnn_cell_impl.LayerRNNCell): super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -3183,7 +3183,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -3323,7 +3323,7 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._forget_bias = forget_bias @@ -3444,7 +3444,7 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self.units = units self.activation = activations.get(activation) @@ -3558,7 +3558,7 @@ class CFNCell(rnn_cell_impl.LayerRNNCell): super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self.units = units self.activation = activations.get(activation) diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index f0947fe423f7e6bf84dae468bc36ca11147ac0bb..269443b2c6508bb618d30f64487b1a6a84e8646f 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -102,7 +102,10 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_oss", # TODO(b/119349471): Re-enable + "no_windows", + ], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 27b5b6d22e0fc1156d6f7a1c852f4c5a6e06da02..ffba514bb96f5ce8d963cb0a0482738eafe88355 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -25,7 +25,6 @@ from tensorflow.python.client import session from tensorflow.python.estimator import keras as estimator_keras_util from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras import models as models_lib @@ -126,7 +125,7 @@ def save_keras_model( export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) temp_export_dir = export_helpers.get_temp_export_dir(export_dir) - builder = saved_model_builder.SavedModelBuilder(temp_export_dir) + builder = saved_model_builder._SavedModelBuilder(temp_export_dir) # Manually save variables to export them in an object-based checkpoint. 
This # skips the `builder.add_meta_graph_and_variables()` step, which saves a @@ -228,9 +227,10 @@ def _export_mode( g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) # Extract update and train ops from train/test/predict functions. + train_op = None if mode == model_fn_lib.ModeKeys.TRAIN: clone._make_train_function() - builder._add_train_op(clone.train_function.updates_op) + train_op = clone.train_function.updates_op elif mode == model_fn_lib.ModeKeys.EVAL: clone._make_test_function() else: @@ -265,7 +265,8 @@ def _export_mode( model_fn_lib.EXPORT_TAG_MAP[mode], signature_def_map=_create_signature_def_map(clone, mode), saver=saver_lib.Saver(clone_var_list), - main_op=variables.local_variables_initializer()) + init_op=variables.local_variables_initializer(), + train_op=train_op) return None @@ -307,31 +308,11 @@ def _create_signature_def_map(model, mode): serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) -def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): +def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument """Assert model and clone contain the same checkpointable objects.""" - def get_non_optimizer_objects(m, g): - """Gather set of model and optimizer checkpointable objects.""" - # Set default graph because optimizer.variables() returns optimizer - # variables defined in the default graph. 
- with g.as_default(): - all_objects = set(checkpointable_utils.list_objects(m)) - optimizer_and_variables = set() - for obj in all_objects: - if isinstance(obj, optimizers.TFOptimizer): - optimizer_and_variables.update(checkpointable_utils.list_objects(obj)) - optimizer_and_variables.update(set(obj.optimizer.variables())) - return all_objects - optimizer_and_variables - - model_objects = get_non_optimizer_objects(model, model_graph) - clone_objects = get_non_optimizer_objects(clone, clone_graph) - - if len(model_objects) != len(clone_objects): - raise errors.InternalError( - None, None, - 'Model and clone must use the same variables.' - '\n\tModel variables: %s\n\t Clone variables: %s' - % (model_objects, clone_objects)) + # TODO(fchollet, kathywu): make sure this works in eager mode. + return True def load_keras_model(saved_model_path): diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index a65b2ce466111c33d0092b7018537573708de2d0..93d73e1b484ed810fb347b13e95022dfca3584c2 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -29,14 +29,12 @@ from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import signature_constants from 
tensorflow.python.training import training as training_module @@ -255,7 +253,7 @@ def load_model(sess, path, mode): outputs = { k: sess.graph.get_tensor_by_name(v.name) for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()} - return inputs, outputs + return inputs, outputs, meta_graph_def @test_util.run_all_in_graph_and_eager_modes @@ -332,8 +330,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): # Load predict graph, and test predictions with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) predictions = sess.run(outputs[output_name], {inputs[input_name]: input_arr}) @@ -342,19 +340,21 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): if optimizer: # Load eval graph, and test predictions, loss and metric values with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.EVAL) + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.EVAL) # First obtain the loss and predictions, and run the metric update op by # feeding in the inputs and targets. loss, predictions, _ = sess.run( (outputs['loss'], outputs['predictions/' + output_name], - outputs['metrics/mae/update_op']), - {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) + outputs['metrics/mean_absolute_error/update_op']), { + inputs[input_name]: input_arr, + inputs[target_name]: target_arr + }) # The metric value should be run after the update op, to ensure that it # reflects the correct value. 
- metric_value = sess.run(outputs['metrics/mae/value']) + metric_value = sess.run(outputs['metrics/mean_absolute_error/value']) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) @@ -364,17 +364,17 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): # Load train graph, and check for the train op, and prediction values with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.TRAIN) + inputs, outputs, meta_graph_def = load_model( + sess, output_path, model_fn_lib.ModeKeys.TRAIN) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) self.assertIn('loss', outputs) - self.assertIn('metrics/mae/update_op', outputs) - self.assertIn('metrics/mae/value', outputs) + self.assertIn('metrics/mean_absolute_error/update_op', outputs) + self.assertIn('metrics/mean_absolute_error/value', outputs) self.assertIn('predictions/' + output_name, outputs) # Train for a step - train_op = ops.get_collection(constants.TRAIN_OP_KEY) + train_op = loader_impl.get_train_op(meta_graph_def) train_outputs, _ = sess.run( [outputs, train_op], {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) @@ -401,8 +401,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): output_path = keras_saved_model.save_keras_model( model, saved_model_path, custom_objects={'relu6': relu6}) with session.Session(graph=ops.Graph()) as sess: - inputs, outputs = load_model(sess, output_path, - model_fn_lib.ModeKeys.PREDICT) + inputs, outputs, _ = load_model(sess, output_path, + model_fn_lib.ModeKeys.PREDICT) input_name = model.input_names[0] output_name = model.output_names[0] predictions = sess.run( @@ -463,11 +463,6 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) clone.train_on_batch(input_arr, target_arr) - with 
self.assertRaisesRegexp( - errors.InternalError, 'Model and clone must use the same variables.'): - keras_saved_model._assert_same_non_optimizer_objects( - model, model_graph, clone, clone_graph) - def testSaveSeqModelWithoutInputShapesRaisesError(self): """A Sequential model that hasn't been built should raise an error.""" model = sequential_model_without_input_shape(True) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 8668c67cf95aba6cbd466142bed37c79e34d9e04..922f21b98b35dfff19c8c605a25e89c5d2da8d98 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -154,8 +154,8 @@ class AttentionWrapperTest(test.TestCase): if attention_layer_sizes is not None: # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. - attention_depth = sum([attention_layer_size or encoder_output_depth - for attention_layer_size in attention_layer_sizes]) + attention_depth = sum(attention_layer_size or encoder_output_depth + for attention_layer_size in attention_layer_sizes) elif attention_layers is not None: # Compute sum of attention_layers output depth. 
attention_depth = sum( diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py index 8fcd7aeef6a6964902666a4f3c17e05b0c7b52ee..f31bdbd399c9de4f2f5d557b75b1ece6d64a765e 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import lanczos from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -80,7 +81,8 @@ if __name__ == "__main__": for shape in [[4, 4], [7, 4], [5, 8]]: for orthogonalize in True, False: for steps in range(1, min(shape) + 1): - for use_static_shape in True, False: + # TF2 does not support placeholders so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_%s_%s_staticshape_%s" % ( dtype.__name__, "_".join(map(str, shape)), orthogonalize, steps, use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py index 2a9100903aae5689919a6b25fcb18ff192f250b3..841a41a2339824ab8ca15f4bdd74be697cd6fe9f 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import least_squares from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -76,7 +77,8 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for shape in [[4, 4], [8, 5], [3, 7]]: - for use_static_shape in 
True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, "_".join(map(str, shape)), use_static_shape) diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py index a0e6eb87bc06fb1303a7eb86fa6760458f20a9b9..10807f7a80617e56abeb6d13ce419a49a2269aac 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python import tf2 from tensorflow.contrib.solvers.python.ops import linear_equations from tensorflow.contrib.solvers.python.ops import util from tensorflow.python.framework import constant_op @@ -113,7 +114,8 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_): if __name__ == "__main__": for dtype in np.float32, np.float64: for size in 1, 4, 10: - for use_static_shape in True, False: + # TF2 does not support placeholders under eager so we skip it + for use_static_shape in set([True, tf2.enabled()]): shape = [size, size] arg_string = "%s_%s_staticshape_%s" % (dtype.__name__, size, use_static_shape) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 4d1807130c57039976dfa57c27bb0d4807e75212..10e4556dacbc17ec02c2bd698389b04d517d7076 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -152,6 +152,27 @@ class EagerFileTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + def testRecordEveryNGlobalSteps(self): + step = training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + + def run_step(): + 
summary_ops.scalar('scalar', i, step=step) + step.assign_add(1) + + with summary_ops.create_file_writer( + logdir).as_default(), summary_ops.record_summaries_every_n_global_steps( + 2, step): + for i in range(10): + run_step() + # And another 10 steps as a graph function. + run_step_fn = function.defun(run_step) + for i in range(10): + run_step_fn() + + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 11) + def testMaxQueue(self): logs = tempfile.mkdtemp() with summary_ops.create_file_writer( @@ -279,12 +300,9 @@ class EagerDbTest(summary_test_util.SummaryDbTest): def testDbURIOpen(self): tmpdb_path = os.path.join(self.get_temp_dir(), 'tmpDbURITest.sqlite') - tmpdb_uri = six.moves.urllib_parse.urljoin("file:", tmpdb_path) - tmpdb_writer = summary_ops.create_db_writer( - tmpdb_uri, - "experimentA", - "run1", - "user1") + tmpdb_uri = six.moves.urllib_parse.urljoin('file:', tmpdb_path) + tmpdb_writer = summary_ops.create_db_writer(tmpdb_uri, 'experimentA', + 'run1', 'user1') with summary_ops.always_record_summaries(): with tmpdb_writer.as_default(): summary_ops.scalar('t1', 2.0) diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc index 3f24f58f03aac2ba6d368d7eccf8731f611a81b4..22b6f09d0cd88068f7bedabe7687920420a3028f 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc @@ -73,7 +73,16 @@ class SummaryFileWriter : public SummaryWriterInterface { e->set_step(global_step); e->set_wall_time(GetWallTime()); Summary::Value* v = e->mutable_summary()->add_value(); - t.AsProtoTensorContent(v->mutable_tensor()); + + if (t.dtype() == DT_STRING) { + // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python + // can convert the TensorProto to string-type numpy array. MakeNdarray + // does not work with strings encoded by AsProtoTensorContent() in + // tensor_content. 
+ t.AsProtoField(v->mutable_tensor()); + } else { + t.AsProtoTensorContent(v->mutable_tensor()); + } v->set_tag(tag); if (!serialized_metadata.empty()) { v->mutable_metadata()->ParseFromString(serialized_metadata); diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc index cd3f712256f2293ed725745f8cbe48109856ef86..ffbfb9533e887e54b0f5bdfde11dadce21073a94 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_file_writer.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/io/path.h" @@ -104,6 +105,23 @@ TEST_F(SummaryFileWriterTest, WriteTensor) { CHECK_EQ(e.summary().value_size(), 1); EXPECT_EQ(e.summary().value(0).tag(), "name"); })); + TF_CHECK_OK(SummaryTestHelper( + "string_tensor_test", + [](SummaryWriterInterface* writer) { + Tensor hello(DT_STRING, TensorShape({})); + hello.scalar()() = "hello"; + TF_RETURN_IF_ERROR(writer->WriteTensor( + 2, hello, "name", SummaryMetadata().SerializeAsString())); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name"); + EXPECT_EQ(e.summary().value(0).tensor().dtype(), DT_STRING); + EXPECT_EQ(e.summary().value(0).tensor().string_val()[0], "hello"); + })); } TEST_F(SummaryFileWriterTest, WriteScalar) { diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 20bcd2447e6fd7eaf11e3e5cf383f6abf168c787..784acce444a8d0c066f1b7ae6c1b5d7d65405549 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ 
b/tensorflow/contrib/tensorrt/BUILD @@ -29,6 +29,10 @@ load( "if_tensorrt", ) +exports_files(glob([ + "test/testdata/*", +])) + tf_cuda_cc_test( name = "tensorrt_test_cc", size = "small", @@ -491,6 +495,7 @@ cuda_py_tests( "test/memory_alignment_test.py", "test/multi_connection_neighbor_engine_test.py", "test/neighboring_engine_test.py", + "test/quantization_test.py", "test/rank_two_test.py", "test/reshape_transpose_test.py", "test/vgg_block_nchw_test.py", @@ -527,6 +532,30 @@ cuda_py_tests( ], ) +cuda_py_test( + name = "quantization_mnist_test", + srcs = ["test/quantization_mnist_test.py"], + additional_deps = [ + ":tf_trt_integration_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/keras:keras", + "//tensorflow/python/estimator:estimator", + ], + data = [ + "test/testdata/checkpoint", + "test/testdata/model.ckpt-46900.data-00000-of-00001", + "test/testdata/model.ckpt-46900.index", + ], + tags = [ + "no_cuda_on_cpu_tap", + "no_pip", + "no_tap", # It is not able to download the mnist data. 
+ "no_windows", + "nomac", + ], +) + cc_library( name = "utils", srcs = ["convert/utils.cc"], diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 26d54eb156ccc8593d82609195caabb5bb929262..ae211a93c3279ff1d6de2f9c9a4b849fc8cd578d 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -82,60 +82,78 @@ std::vector GetLoadedTensorRTVersion() { } TrtCandidateSelector::TrtCandidateSelector( - const grappler::GraphProperties& graph_properties) - : graph_properties_(graph_properties) {} + const grappler::GraphProperties& graph_properties, int precision_mode) + : graph_properties_(graph_properties), precision_mode_(precision_mode) {} Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", - "AvgPool", - "ConcatV2", - "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", - "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", - "Exp", - "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", - "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", - "Max", - "Min", + "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Sigmoid", + "Tanh", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + "Div", + "RealDiv", + "Rsqrt", + "Reciprocal", + "Exp", + "Log", + "Sqrt", + "Abs", + "Neg", + "Transpose", + "Reshape", + "MatMul", + "BatchMatMul", + "Softmax", + "Minimum", + "Maximum", + "TopKV2", + "Sum", + "Prod", + "Max", + "Min", + "Relu6", + "Square", 
+ "ExpandDims", + "Squeeze", }; - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) - const bool is_supported_op_type = + bool is_supported_op_type = (candidate_ops.count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); + static const std::set quantize_ops = { + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxArgs", + }; + // In INT8 mode, we will always apply the quantization ranges provided by + // these ops to the relevant tensors. This happens regardless of the value of + // use_calibration. + if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { + is_supported_op_type = true; + } + // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) if (!is_supported_op_type) { return errors::Unimplemented("Op type ", node->type_string(), - " is not supported."); + " is not supported"); } std::vector input_edges; @@ -170,7 +188,7 @@ tensorflow::Status BuildNodeMap( tensorflow::Status ConvertCalibGraphToInferGraph( const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, bool is_dyn_op) { - VLOG(0) << "Starting Calib Conversion"; + LOG(INFO) << "Starting Calib Conversion"; infer_graph->CopyFrom(graph_def); auto trt_rm = TRTResourceManager::instance(); auto calib_rm = trt_rm->getManager("TRTCalibration"); @@ -220,18 +238,19 @@ tensorflow::Status ConvertGraphDefToTensorRT( const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode, int minimum_segment_size, bool is_dyn_op, - int max_cached_engines, std::vector cached_engine_batches) { + int max_cached_engines, std::vector cached_engine_batches, + bool use_calibration) { // Create GrapplerItem. 
tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - // TODO(aaroey): we should have used single machine cluster like the - // following, but the problem is then wrap_conversion will depend on - // direct_session and cause double linking problems. To fix this we need to - // fix or get rid of the swig dependency. Here we use VirtualCluster - // as a work around, and we need to create a session to initialize the - // underlying device before calling this method. +// TODO(aaroey): we should have used single machine cluster like the +// following, but the problem is then wrap_conversion will depend on +// direct_session and cause double linking problems. To fix this we need to +// fix or get rid of the swig dependency. Here we use VirtualCluster +// as a work around, and we need to create a session to initialize the +// underlying device before calling this method. #if 0 // Create single machine cluster. Note that this will create a session and // initialize the gpu devices. @@ -264,7 +283,9 @@ tensorflow::Status ConvertGraphDefToTensorRT( #endif // Create RewriterConfig. - tensorflow::RewriterConfig rw_cfg; + tensorflow::ConfigProto config_proto; + auto& rw_cfg = + *config_proto.mutable_graph_options()->mutable_rewrite_options(); // TODO(aaroey): use only const folding and layout for the time being since // new optimizers break the graph for trt. rw_cfg.add_optimizers("constfold"); @@ -285,9 +306,10 @@ tensorflow::Status ConvertGraphDefToTensorRT( list->add_i(batch); } } + parameters["use_calibration"].set_b(use_calibration); // Run optimizer. 
- tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); + tensorflow::grappler::MetaOptimizer meta_opt(nullptr, config_proto); TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); if (VLOG_IS_ON(5)) { @@ -433,7 +455,8 @@ tensorflow::Status GetEngineInfo( << "but this shouldn't have happened"; info->device = *segment_devices.begin(); } else { - LOG(ERROR) << "Can't find a device placement for the op!"; + VLOG(1) << "No device is assigned to the segment. " + << "A device will be assigned during graph execution (inference)."; } return Status::OK(); } @@ -564,27 +587,38 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } } + // We don't support segments with no inputs. Fall back to native TF here to + // avoid crash later. Constant folding should've folded the ops that make up + // these segments. + if (inputs.empty()) { + return tensorflow::errors::Internal( + "Segment has no inputs (possible " + "constfold failure)"); + } + + const bool calibrate_int8 = + (info.precision_mode == INT8MODE && info.use_calibration); + // Build the engine and get its serialized representation. string segment_string; - if (info.engine_type == EngineInfo::EngineType::TRTStatic || - info.precision_mode == INT8MODE) { + if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { // Create static engine for fp32/fp16 mode, and test validity of the engine - // for int8 mode. We don't want engine to fail at the calibration time. - // So we are constructing a FP32 engine here to check its validity, and if - // it is a valid engine then we put the serialized graphdef to the op. - // Otherwise we skip node creation for this engine. + // for int8 calibration mode. We don't want engine to fail at the + // calibration time. So we are constructing a FP32 engine here to check its + // validity, and if it is a valid engine then we put the serialized graphdef + // to the op. Otherwise we skip node creation for this engine. 
Logger trt_logger; TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( - info.segment_graph_def, - info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, + info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, &trt_logger, alloc, /*calibrator=*/nullptr, &engine, + info.use_calibration, /*convert_successfully=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); segment_string = string((const char*)engine_data->data(), engine_data->size()); - if (info.precision_mode == INT8MODE) { + if (calibrate_int8) { // See above comment about why not putting this inside the 'else' branch. segment_string = info.segment_graph_def.SerializeAsString(); } @@ -596,7 +630,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, // conversion. string prec_string; TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); - if (info.precision_mode == INT8MODE && + if (info.precision_mode == INT8MODE && calibrate_int8 && !TRTResourceManager::instance()->getManager("TRTCalibration")) { LOG(ERROR) << "Failed to construct calibration storage"; } @@ -632,6 +666,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, .Attr("cached_engine_batches", {max_batch_size}) .Attr("workspace_size_bytes", info.max_workspace_size_bytes) .Attr("precision_mode", prec_string) + .Attr("use_calibration", info.use_calibration) .Attr("OutT", out_types) .Finalize(&trt_node); if (!status.ok()) { @@ -864,7 +899,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; - TrtCandidateSelector candidate_selector(*params.graph_properties); + TrtCandidateSelector candidate_selector(*params.graph_properties, + params.precision_mode); 
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( &graph, std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector, @@ -873,10 +909,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { // need to check the input edges. [](const Edge* edge) { return true; }, OutputEdgeValidator(), segment_options, &initial_segments)); - if (initial_segments.size() > 1) { - VLOG(0) << "MULTIPLE tensorrt candidate conversion: " + LOG(INFO) << "Number of TensorRT candidate segments: " << initial_segments.size(); - } // Get the EngineInfo for each segment. std::unordered_map node_map; @@ -902,13 +936,17 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { continue; } curr_engine.precision_mode = params.precision_mode; - curr_engine.engine_type = - (params.is_dyn_op || params.precision_mode == INT8MODE - ? EngineInfo::EngineType::TRTDynamic - : EngineInfo::EngineType::TRTStatic); + if (params.use_calibration && params.precision_mode != INT8MODE) { + return errors::InvalidArgument( + "Calibration with FP32 or FP16 is not supported."); + } + curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration) + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); + curr_engine.use_calibration = params.use_calibration; curr_engine.cached_engine_batches = params.cached_engine_batches; curr_engine.maximum_cached_engines = params.max_cached_engines; - StrAppend(&curr_engine.engine_name, "my_trt_op_", t); + StrAppend(&curr_engine.engine_name, "TRTEngineOp_", t); status = RegisterSegmentFunctionToFunctionLibrary( &graph, curr_engine.segment_graph_def, curr_engine.engine_name); if (!status.ok()) { @@ -969,16 +1007,9 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. 
- string msg = StrCat("Engine ", engine.engine_name, " creation for segment ", - i, ", composed of ", + string msg = StrCat("TensorRT node ", engine.engine_name, + " added for segment ", i, " consisting of ", converted_segments.at(i).first.size(), " nodes"); - if (VLOG_IS_ON(1)) { - StrAppend(&msg, " ("); - for (const string& node_name : converted_segments.at(i).first) { - StrAppend(&msg, node_name, ", "); - } - StrAppend(&msg, ")"); - } if (status.ok()) { LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { @@ -986,7 +1017,14 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } } else { // Graph is not modified. - LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Fallback to TF..."; + } + if (VLOG_IS_ON(1)) { + msg = "Segment consists of nodes: "; + for (const string& node_name : converted_segments.at(i).first) { + StrAppend(&msg, node_name, ", "); + } + VLOG(1) << msg; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 1c9d82105a7b380cafbb27c340a4cc9d1580ee2c..1f39f56f6392ba33af3d74fec12c326ed4451cb6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -35,7 +35,8 @@ namespace convert { // supported by TRT. class TrtCandidateSelector { public: - TrtCandidateSelector(const grappler::GraphProperties& graph_properties); + TrtCandidateSelector(const grappler::GraphProperties& graph_properties, + int precision_mode); // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added // to TRT subgraph and later converted into TRT engine. @@ -49,6 +50,9 @@ class TrtCandidateSelector { // GraphProperties of the graph whose nodes are to be validated by // IsTensorRTCandidate(). 
const grappler::GraphProperties& graph_properties_; + + // Quantization ops are only converted when using quantized precisions. + const int precision_mode_; }; struct ConversionParams { @@ -63,6 +67,7 @@ struct ConversionParams { cluster(nullptr), is_dyn_op(false), fixed_input_size(true), + use_calibration(true), max_cached_engines(1) {} const tensorflow::GraphDef* input_graph_def; const std::vector* output_names; @@ -76,6 +81,7 @@ struct ConversionParams { bool is_dyn_op; // Whether to create engine on conversion or execution time bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed int max_cached_engines; // maximum number of cached engines + bool use_calibration; std::vector cached_engine_batches; // list of cached engines }; @@ -95,7 +101,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode = 1, int minimum_segment_size = 3, bool is_dyn_op = false, int max_cached_engines = 1, - std::vector cached_engine_batches = {}); + std::vector cached_engine_batches = {}, bool use_calibration = true); // Method to call from optimization pass tensorflow::Status ConvertAfterShapes(ConversionParams& params); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc index f10729987fdb787c6a745fdac28fe7d7d60d08fa..2d2bfeb192c1893824c7b30bfad593c62c203392 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -85,27 +85,42 @@ TEST(TrtCandidateSelector, Basics) { ops::MatMul(s.WithOpName("matmul_with_incompatible_input"), incompatible_feed, const_2); + // Quantize ops. + auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); + auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed, + quantize_attrs); + + // Get GrapplerItem and GraphProperties. 
grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); Tensor feed_tensor(DT_FLOAT, input_shape); item.feed.push_back(std::make_pair("feed", feed_tensor)); - grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - TrtCandidateSelector selector(graph_properties); - TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); - ExpectStatus( - selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), - error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected " - "(op: MatMul), at: incompatible_matmul"); - ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), - error::UNIMPLEMENTED, "Op type Sin is not supported"); - ExpectStatus(selector.IsTensorRTCandidate( - matmul_with_incompatible_input.operation.node()), - error::INTERNAL, - "Failed to convert input with index 0 to a TRT_TensorOrWeights"); + for (const int precision_mode : {FP32MODE, INT8MODE}) { + TrtCandidateSelector selector(graph_properties, precision_mode); + TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); + ExpectStatus( + selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), + error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected " + "(op: MatMul), at: incompatible_matmul"); + ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), + error::UNIMPLEMENTED, "Op type Sin is not supported"); + ExpectStatus( + selector.IsTensorRTCandidate( + matmul_with_incompatible_input.operation.node()), + error::INTERNAL, + "Failed to convert input with index 0 to a TRT_TensorOrWeights"); + if (precision_mode == INT8MODE) { + TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node())); + } else { + ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()), + error::UNIMPLEMENTED, + "Op type FakeQuantWithMinMaxArgs is not supported"); + } + } } class FakeCluster : public grappler::Cluster 
{ diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index e2988f5f2a8f6164cbe193573b267e6ffeef3284..777a80bbc4da7a260cf85d0a7bc5ec16f4cd3cab 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -54,10 +54,10 @@ limitations under the License. // would work! #define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) -#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ - do { \ - return tensorflow::errors::Internal( \ - "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \ +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return tensorflow::errors::Internal( \ + "TFTRT::", __FUNCTION__, " failed to add TRT layer, at: ", node); \ } while (0) #define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ @@ -120,6 +120,15 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape, return trt_dims; } +Status TensorShapeArrayToTrtDims(const std::vector& shape, + nvinfer1::Dims* out, + bool ignore_first_dim = false) { + PartialTensorShape tensor_shape; + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(shape, &tensor_shape)); + *out = TensorShapeToTrtDims(tensor_shape, ignore_first_dim); + return tensorflow::Status::OK(); +} + void GetOutputProperties(const grappler::GraphProperties& graph_properties, const Node* node, const int out_port, PartialTensorShape* shape, @@ -130,7 +139,7 @@ void GetOutputProperties(const grappler::GraphProperties& graph_properties, *dtype = out_shape.dtype(); *shape = out_shape.shape(); } else { - VLOG(0) << "Unknown output shape" << node->name(); + LOG(INFO) << "Unknown output shape" << node->name(); *dtype = node->output_type(out_port); } } @@ -181,16 +190,55 @@ Status ValidateTensorProperties(const string& producer_node_type, if (shape.dim_size(d) < 0) { return errors::InvalidArgument( "Input tensor with shape ", shape.DebugString(), - " has an unknown non-batch 
dimemension at dim ", d); + " has an unknown non-batch dimension at dim ", d); } } return Status::OK(); } +string DebugString(const nvinfer1::DimensionType type) { + switch (type) { + case nvinfer1::DimensionType::kSPATIAL: + return "kSPATIAL"; + case nvinfer1::DimensionType::kCHANNEL: + return "kCHANNEL"; + case nvinfer1::DimensionType::kINDEX: + return "kINDEX"; + case nvinfer1::DimensionType::kSEQUENCE: + return "kSEQUENCE"; + default: + return StrCat(static_cast(type), "=unknown"); + } +} + +string DebugString(const nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return "kFLOAT"; + case nvinfer1::DataType::kHALF: + return "kHALF"; + case nvinfer1::DataType::kINT8: + return "kINT8"; + case nvinfer1::DataType::kINT32: + return "kINT32"; + default: + return "Invalid TRT data type"; + } +} + string DebugString(const nvinfer1::Dims& dims) { string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); for (int i = 0; i < dims.nbDims; ++i) { - StrAppend(&out, dims.d[i], ","); + StrAppend(&out, dims.d[i], "[", DebugString(dims.type[i]), "],"); + } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::Permutation& permutation, int len) { + string out = "nvinfer1::Permutation("; + for (int i = 0; i < len; ++i) { + StrAppend(&out, permutation.order[i], ","); } StrAppend(&out, ")"); return out; @@ -198,16 +246,15 @@ string DebugString(const nvinfer1::Dims& dims) { string DebugString(const nvinfer1::ITensor& tensor) { return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), - ", shape=", DebugString(tensor.getDimensions()), ")"); + ", name=", tensor.getName(), + ", dtype=", DebugString(tensor.getType()), + ", dims=", DebugString(tensor.getDimensions()), ")"); } -// Return whether or not the broadcast is feasible; -bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, - const bool operand_l_is_tensor, - const nvinfer1::Dims& operand_r, - const bool operand_r_is_tensor, - 
nvinfer1::Dims* operand_l_new_shape, - nvinfer1::Dims* operand_r_new_shape) { +Status Converter::GetTrtBroadcastShape( + const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) const { // *************************************************************************** // TensorRT Elementwise op supports broadcast but requires both tensor to be // of Identical rank @@ -232,52 +279,59 @@ bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, // -> T: 1 1 1 -1 3 5 1 // -> W: 1 1 1 1 3 5 1 // *************************************************************************** - const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; - const size_t element_size = sizeof(operand_l.d[0]); - - // fill in dimensions - int l_s[max_nb_dims]; - std::fill(l_s, l_s + max_nb_dims, 1); - int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims; - int r_s[max_nb_dims]; - std::fill(r_s, r_s + max_nb_dims, 1); - int r_d = operand_r_is_tensor ? 
operand_r.nbDims + 1 : operand_r.nbDims; - - int max_d = std::max(l_d, r_d); - std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d, - operand_l.nbDims * element_size); - std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d, - operand_r.nbDims * element_size); - - // set -1 for batch dimension, since batch size is not supposed to be - // broadcasted - if (operand_l_is_tensor) { - if (max_d != l_d) { // if broadcast beyond batch dimension, fail - return false; - } - l_s[0] = -1; - } - if (operand_r_is_tensor) { - if (max_d != r_d) { // if broadcast beyond batch dimension, fail - return false; - } - r_s[0] = -1; + if (!operand_l.is_tensor() && !operand_r.is_tensor()) { + return errors::InvalidArgument( + "Broadcasting requires at least one of the operands be tensors"); } - // compare broadcast feasibility - for (int i = max_d - 1; i >= 0; i--) { - if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) { - return false; + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; + auto compute_output_dims = + [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims, + int* output_dims_array, nvinfer1::Dims* output_dims) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, + output_dims_array + broadcast_num_dims - input_dims.nbDims); + if (input.is_tensor()) { + const int true_input_dims = input_dims.nbDims + 1; + if (true_input_dims < broadcast_num_dims) { + return errors::InvalidArgument( + "Broadcasting beyond batch dimension is not supported ", + "(tensor #dims ", true_input_dims, " vs broadcast #dims ", + broadcast_num_dims, ")"); + } + // Set the batch dimension to -1, since batch size is not supposed to + // be broadcasted. + output_dims_array[0] = -1; + } + // Copy to output dimensions (stripping the batch dimension). 
+ output_dims->nbDims = broadcast_num_dims - 1; + std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, + output_dims->d); + return Status::OK(); + }; + + // Compute the output dimensions. + const int broadcast_num_dims = + std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0), + operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0)); + int output_l[max_nb_dims], output_r[max_nb_dims]; + TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims, + output_l, operand_l_new_dims)); + TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims, + output_r, operand_r_new_dims)); + + // Compare broadcast feasibility + for (int i = 0; i < broadcast_num_dims; ++i) { + if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && + (output_r[i] != 1)) { + return errors::InvalidArgument( + "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", + DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], + ", ", DebugString(*operand_r_new_dims), ")"); } } - - // output new TensorRT Dimension (stripping the batch dimension) - operand_l_new_shape->nbDims = max_d - 1; - std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size); - operand_r_new_shape->nbDims = max_d - 1; - std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size); - - return true; + return Status::OK(); } inline bool DimsEqual(const nvinfer1::Dims& dim_l, @@ -381,7 +435,7 @@ size_t TRT_ShapedWeights::size_bytes() const { string TRT_ShapedWeights::DebugString() const { return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", type_, + ", type=", DataTypeString(type_), ", values=", reinterpret_cast(GetValues()), ")"); } @@ -425,7 +479,9 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { void setLocation(nvinfer1::TensorLocation location) override {} #if NV_TENSORRT_MAJOR >= 5 - bool setDynamicRange(float min, float max) override {} + bool 
setDynamicRange(float min, float max) override { return true; } + + float getDynamicRange() const override { return 0; } #endif private: @@ -489,8 +545,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor @", reinterpret_cast(tensor()), - ", shape=", convert::DebugString(tensor()->getDimensions()), + StrAppend(&output, "tensor=", convert::DebugString(*tensor()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -753,8 +808,9 @@ Status TrtNodeValidator::ValidateNode( Status status = ConvertToTensorOrWeights( *pair.first, pair.second, graph_properties, &tensor_or_weights); if (!status.ok()) { - return errors::Internal("Failed to convert input with index ", i, - " to a TRT_TensorOrWeights"); + return errors::Internal( + "Failed to convert input with index ", i, + " to a TRT_TensorOrWeights: ", status.error_message()); } inputs.push_back(tensor_or_weights); } @@ -786,8 +842,11 @@ Status TrtNodeValidator::ConvertConstToWeights( return status; } -Converter::Converter(nvinfer1::INetworkDefinition* trt_network, bool is_fp16) - : trt_network_(trt_network), is_fp16_(is_fp16) { +Converter::Converter(nvinfer1::INetworkDefinition* trt_network, + int precision_mode, bool use_calibration) + : trt_network_(trt_network), + precision_mode_(precision_mode), + use_calibration_(use_calibration) { this->RegisterOpConverters(); } @@ -812,13 +871,18 @@ Status Converter::ConvertNode(const NodeDef& node_def) { TRT_TensorOrWeights& output = outputs[i]; string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); - // We need to check the name before setting it. For Identity op where the - // output is the input, if its input is one of the engine input, setting - // the name here will overwrite engine input bindings which will cause - // runtime error. 
+ // We need to check the name before setting it. If the input is one of the + // engine input, setting the name here will overwrite engine input + // bindings which will cause runtime error. if (output.is_tensor()) { const char* tensor_name = output.tensor()->getName(); - if (tensor_name == nullptr || std::strlen(tensor_name) == 0) { + if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) { + // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename + // them to match their corresponding TensorFlow name. + // Note: ITensors that we create internally within TF-TRT which are + // not inputs or outputs of a node will not be renamed. This is a + // potential cause of confusion if an error message or warning + // mentions the unnamed tensor. output.tensor()->setName(output_name.c_str()); } } @@ -930,11 +994,14 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose"); + MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0)); nvinfer1::Permutation permutation; for (int32_t i = 0; i < dims.nbDims; ++i) { permutation.order[i] = order_with_batch_dim[i + 1] - 1; } + VLOG(1) << "TransposeTensor permutation: " + << DebugString(permutation, dims.nbDims); layer->setFirstTranspose(permutation); nvinfer1::Dims reshape_dims; @@ -950,6 +1017,38 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, return tensorflow::Status::OK(); } +Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, + float* out_min, float* out_max) const { + switch (weights.type_) { + case DataType::DT_FLOAT: { + auto inp = static_cast(weights.GetValues()); + auto result = std::minmax_element(inp, inp + weights.count()); + *out_min = *result.first; + *out_max = *result.second; + break; + } + case DataType::DT_HALF: { + auto inp = static_cast(weights.GetValues()); + auto result = 
std::minmax_element(inp, inp + weights.count()); + *out_min = Eigen::half_impl::half_to_float(*result.first); + *out_max = Eigen::half_impl::half_to_float(*result.second); + break; + } + case DataType::DT_INT32: { + auto inp = static_cast(weights.GetValues()); + auto result = std::minmax_element(inp, inp + weights.count()); + *out_min = static_cast(*result.first); + *out_max = static_cast(*result.second); + break; + } + default: + return errors::Unimplemented( + "Data type not supported for GetWeightRange: ", + DataTypeString(weights.type_)); + } + return Status::OK(); +} + Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const nvinfer1::ITensor** tensor) { @@ -964,8 +1063,9 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, } if (can_check_shapes && TrtDimsNumElements(input.GetTrtDims()) != TrtDimsNumElements(dims)) { - return tensorflow::errors::InvalidArgument( - "Reshape shapes are not compatible."); + return errors::InvalidArgument("Reshape shapes are not compatible (", + DebugString(input.GetTrtDims()), " vs ", + DebugString(dims), ")"); } if (input.is_tensor()) { @@ -976,6 +1076,8 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, *const_cast(input.tensor())); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); layer->setReshapeDimensions(dims); + MarkQuantizationRangesAsInferrable( + const_cast(input.tensor()), layer->getOutput(0)); *tensor = layer->getOutput(0); } } else { @@ -983,10 +1085,123 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, this->network()->addConstant(dims, input.weights().GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); *tensor = layer->getOutput(0); + if (precision_mode() == INT8MODE && !use_calibration()) { + // If we are in int8 mode and not calibrating, we need to explicitly set a + // quantization range for the output tensor of the IConstantLayer. 
Here we + // set the range to [min(weights), max(weights)]. + float min_range = 0.0f; + float max_range = 0.0f; + TF_RETURN_IF_ERROR( + GetWeightRange(input.weights(), &min_range, &max_range)); + // Avoid setting range to 0 because TRT will throw an error. If the + // weights are zero then the range doesn't matter: using 127.0f should + // ensure the quantized weight will be exactly zero. + if (min_range == 0.0f && max_range == 0.0f) { + min_range = -127.0f; + max_range = 127.0f; + } + ProvideQuantizationRange(const_cast(*tensor), + min_range, max_range); + } } return tensorflow::Status::OK(); } +void Converter::MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, + nvinfer1::ITensor* output) { + quantization_infer_.push_back({input, output}); + quantization_infer_.push_back({output, input}); +} + +void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor, + float min_range, float max_range) { + float symmetric_range = std::max(std::abs(min_range), std::abs(max_range)); + quantization_ranges_[tensor] = symmetric_range; +} + +void Converter::MaybeApplyQuantizationRanges() { + if (precision_mode() != INT8MODE) return; + + // Infer ranges across marked ops. + PropagateQuantizationRanges(); + // Apply ranges. +#if NV_TENSORRT_MAJOR >= 5 + for (auto pair : quantization_ranges_) { + nvinfer1::ITensor* tensor = pair.first; + const float range = pair.second; + VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range; + // TODO(laigd): if 'tensor' already has a range set which doesn't match + // 'range', it should report error. + tensor->setDynamicRange(-range, range); + } +#endif + + // Warn user about tensors that are missing ranges. If TRT fuses some layers + // then these tensors may not actually be required, which is why this is + // just a warning. If we are still missing ranges even after fusion, + // Builder::buildCudaEngine() will return nullptr and we will catch the + // error at that point. 
+ if (!use_calibration()) { + // Get all tensors from network + std::set all_tensors; + for (int i = 0; i < this->network()->getNbLayers(); i++) { + nvinfer1::ILayer* layer = this->network()->getLayer(i); + for (int j = 0; j < layer->getNbInputs(); j++) { + all_tensors.insert(layer->getInput(j)); + } + for (int j = 0; j < layer->getNbOutputs(); j++) { + all_tensors.insert(layer->getOutput(j)); + } + } + // Find tensors with no ranges + for (auto tensor : all_tensors) { + if (!quantization_ranges_.count(tensor)) { + // Note: there may be some warnings for "(Unnamed ITensor* N)". These + // are tensors which are created internally by TF-TRT. The ranges for + // these unnamed ITensors are always inferred from user provided ranges, + // thus there will also be a warning for the range(s) the user missed. + LOG(WARNING) << "Quantization range was not found for " + << tensor->getName() << ". " + << "This is okay if TensorRT does not need the range " + << "(e.g. due to node fusion)."; + } + } + } +} + +void Converter::PropagateQuantizationRanges() { + // Propagate ranges across edges in quantization_infer_ until no new + // information is added. + // Note: this function modifies quantization_infer_, it might be better to + // modify a copy instead if we for some reason need quantization_infer_ + // later. + bool information_added = true; + while (information_added) { + information_added = false; + for (auto it = quantization_infer_.begin(); + it != quantization_infer_.end();) { + auto input_tensor_range = quantization_ranges_.find(it->first); + auto output_tensor_range = quantization_ranges_.find(it->second); + if (input_tensor_range != quantization_ranges_.end() && + output_tensor_range == quantization_ranges_.end()) { + // Input has range but output doesn't: copy range + // TODO(laigd): consider reporting error if it a different range is + // already set. 
+ quantization_ranges_[it->second] = input_tensor_range->second; + information_added = true; + VLOG(1) << "Copy quantization range: " << it->first->getName() << " -> " + << it->second->getName(); + } + // We can remove edges when the output range is known + if (quantization_ranges_.find(it->second) != quantization_ranges_.end()) { + it = quantization_infer_.erase(it); + } else { + ++it; + } + } + } +} + Status Converter::GetInputs(const tensorflow::NodeDef& node_def, std::vector* inputs) const { for (auto const& input_name : node_def.input()) { @@ -1043,12 +1258,11 @@ TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, } // **************************************************************************** -// Constant folding functions -// TODO(jie): once optimizer kicks in, we should have done constant folding -// there. +// Constant folding functions for weights. +// TODO(laigd): we should probably use eigen directly. // ***************************************************************************** struct LambdaFactory { - enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP }; + enum class OP_CATEGORY : int { RSQRT = 0, NEG, RECIP }; OP_CATEGORY op; template @@ -1063,84 +1277,10 @@ struct LambdaFactory { case OP_CATEGORY::RECIP: return [](T t) -> T { return 1.0 / t; }; default: - VLOG(2) << "Not supported op for unary: " << static_cast(op); + LOG(ERROR) << "Not supported op for unary: " << static_cast(op); return nullptr; } } - - template - std::function binary() { - switch (op) { - case OP_CATEGORY::ADD: - return [](T l, T r) -> T { return l + r; }; - case OP_CATEGORY::SUB: - return [](T l, T r) -> T { return l - r; }; - case OP_CATEGORY::MUL: - return [](T l, T r) -> T { return l * r; }; - default: - LOG(WARNING) << "Not supported op for binary: " << static_cast(op); - } - return [](T l, T r) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } - - template - std::function broadcast_r(T val) { - VLOG(2) << "LAMBDA VAL : " << val; 
- switch (op) { - case OP_CATEGORY::ADD: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l + val; - }; - case OP_CATEGORY::SUB: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l - val; - }; - case OP_CATEGORY::MUL: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l * val; - }; - default: - LOG(WARNING) << "Not supported op for binary: " << static_cast(op); - } - return [val](T l) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } - - template - std::function broadcast_l(T val) { - VLOG(2) << "LAMBDA VAL : " << val; - switch (op) { - case OP_CATEGORY::ADD: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val + l; - }; - case OP_CATEGORY::SUB: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val - l; - }; - case OP_CATEGORY::MUL: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val * l; - }; - default: - LOG(ERROR) << "Not supported op for binary: " << static_cast(op); - } - return [val](T l) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } }; template <> @@ -1148,15 +1288,18 @@ std::function LambdaFactory::unary() { switch (op) { case OP_CATEGORY::RSQRT: { VLOG(2) << "RSQRT GETS DONE"; - return [](Eigen::half t) -> Eigen::half { + return [](Eigen::half t) { return Eigen::half(1.0 / sqrt(static_cast(t))); }; } case OP_CATEGORY::NEG: - return [](Eigen::half t) -> Eigen::half { return -t; }; - // TODO(aaroey): can we support RECIP? 
+ return [](Eigen::half t) { return -t; }; + case OP_CATEGORY::RECIP: + return [](Eigen::half t) { + return Eigen::half(1.0 / static_cast(t)); + }; default: - VLOG(2) << "Not supported op for unary: " << static_cast(op); + LOG(ERROR) << "Not supported op for unary: " << static_cast(op); return nullptr; } } @@ -1188,50 +1331,48 @@ tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, return tensorflow::Status::OK(); } +// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the +// right operand. If swapped_inputs is true, those two are swapped. +// // TODO(jie): broadcast is needed yet not implemented. -// Only implemented channel wise for the time being -tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, - const nvinfer1::ITensor* tensor, - TRT_ShapedWeights weights, - bool swapped_inputs) { +// Only implemented channel wise for the time being. +Status BinaryTensorOpWeight(OpConverterParams* params, + const nvinfer1::ITensor* tensor, + TRT_ShapedWeights weights, bool swapped_inputs) { + static const std::unordered_set supported_ops = {"Sub", "Add", "Mul", + "Div", "RealDiv"}; const auto& node_def = params->node_def; - // tensor is the left operand while weights is the right operand; - // when swapped_inputs set to true, those two are swapped. - // TODO(aaroey): use a set. - if (node_def.op() != "Sub" && node_def.op() != "Add" && - node_def.op() != "Mul" && node_def.op() != "Div" && - node_def.op() != "RealDiv") { - return tensorflow::errors::Unimplemented( - "op not supported: " + node_def.op() + ", at: " + node_def.name()); + if (!supported_ops.count(node_def.op())) { + return errors::Unimplemented(node_def.op(), " is not supported, at ", + node_def.name()); } - // Check type consistency - nvinfer1::DataType ttype; - TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &ttype)); + // Check type consistency. 
+ nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype)); - // Check scale mode + // Check scale mode. auto dims_w = weights.shape_; - auto dims_t = tensor->getDimensions(); + const auto dims_t = tensor->getDimensions(); // TODO(jie): addScale checks for input tensor dimension if (dims_t.nbDims != 3) { - return tensorflow::errors::InvalidArgument( - "addScale requires tensor with rank 3, " + node_def.name()); + return errors::InvalidArgument("addScale requires tensor with rank 3, at ", + node_def.name()); } - // default to element-wise + // Default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; // TODO(jie): maybe use a permutation instead to support more cases; - bool permutation_flag = false; + bool need_to_permute = false; if (weights.count() == 1) { - VLOG(2) << "UNIFORM"; scale_mode = nvinfer1::ScaleMode::kUNIFORM; } else { - // no broadcasting on Batch dimension; - VLOG(2) << "WEIGHTS DIM: " << dims_w.nbDims - << " tensor DIM: " << dims_t.nbDims; + VLOG(2) << "weights dims: " << DebugString(dims_w) + << "; tensor dims: " << DebugString(dims_t); + // Make sure no broadcasting on batch dimension. 
if (dims_w.nbDims == dims_t.nbDims + 1) { if (dims_w.d[0] == 1) { for (int i = 1; i < dims_w.nbDims; i++) { @@ -1239,72 +1380,70 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } dims_w.nbDims--; } else { - return tensorflow::errors::InvalidArgument( - "Binary op cannot operate on batch, " + node_def.name()); + return errors::InvalidArgument("Binary op cannot operate on batch, at ", + node_def.name()); } } if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) { scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - // default is element; + // Default is element-wise for (int i = 1; i < dims_w.nbDims; i++) { if (dims_w.d[i] != dims_t.d[i]) { - // if dimension does not match, switch back to channel; - VLOG(2) << "channel"; + // If dimension does not match, switch back to per-channel scale_mode = nvinfer1::ScaleMode::kCHANNEL; break; } } - // if channel as candidate, validate it + // If the mode is per-channel, since channel dimension is assumed to be + // the third to last dimension, we need to make sure all other dimensions + // have size 1. if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { for (int i = 1; i < dims_w.nbDims; i++) { if (dims_w.d[i] != 1) - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + return errors::InvalidArgument( + "Weight dims not compatible for channel-wise broadcast at ", + node_def.name()); } - } else { - VLOG(2) << "elementwise"; } } else if (dims_w.nbDims == 1 && dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) { - // channel wise and broadcast required; - permutation_flag = true; + // Channel wise and broadcast required. We compare the last dimension of + // the tensor shape because of tensorflow default broadcasting rules. 
+ need_to_permute = true; scale_mode = nvinfer1::ScaleMode::kCHANNEL; } else { - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + return errors::InvalidArgument("Weight dims not compatible at ", + node_def.name()); } } + // TODO(laigd): we should add validation_only support in TransposeTensor() and + // PrepareTensorForShape(). + if (params->validation_only) return Status::OK(); - // transpose last dimension + // Transpose last dimension. std::vector permutation(dims_t.nbDims + 1); - if (permutation_flag) { - if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) { - // we swap the last dimension into channel for trt. - // because of tensorflow default broadcasting rules. - for (int i = 0; i < static_cast(permutation.size()); i++) { - permutation[i] = i; - } - permutation[1] = dims_t.nbDims; - permutation[dims_t.nbDims] = 1; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), permutation, &tensor)); - } else { - return tensorflow::errors::InvalidArgument( - "Transpose cannot be applied, " + node_def.name()); - } + if (need_to_permute) { + // We swap the last dimension into channel for trt, because of tensorflow + // default broadcasting rules. 
+ for (int i = 0; i < static_cast(permutation.size()); i++) { + permutation[i] = i; + } + permutation[1] = dims_t.nbDims; + permutation[dims_t.nbDims] = 1; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + const_cast(tensor), permutation, &tensor)); } - if (params->converter->is_fp16()) { + if (params->converter->precision_mode() == FP16MODE) { weights = ConvertFP32ToFP16(params->weight_store, weights); } - // prepare weights + // Prepare weights TRT_ShapedWeights shift_weights(weights.type_); TRT_ShapedWeights scale_weights(weights.type_); TRT_ShapedWeights power_weights(weights.type_); - // Maybe I should do a switch if (node_def.op() == "Sub") { if (swapped_inputs) { shift_weights = weights; @@ -1312,6 +1451,10 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, *const_cast(tensor), nvinfer1::UnaryOperation::kNEG); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Since quantization ranges are symmetric, the same range as the input + // will work for the negation of the input. + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), layer->getOutput(0)); tensor = layer->getOutput(0); } else { TRT_ShapedWeights neg_weights = @@ -1323,6 +1466,25 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { if (swapped_inputs) { + // We need to infer the quantization range for this intermediate tensor. + // + // x -> [Recip] -> 1/x -> [Scale] -> s/x + // ^ + // need range for this + // + // We have the quantization scales for x and s/x - can we divide the scale + // for s/x by s? Only if it is a scalar. + // + // Because of this issue, fall back to BinaryTensorOpTensor if we are + // doing INT8 with no calibration. There is most likely no performance + // penalty by falling back here. 
+ if (params->converter->precision_mode() == INT8MODE && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration. Falling back to BinaryTensorOpTensor for ", + node_def.op(), ", at ", node_def.name()); + } scale_weights = weights; nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( *const_cast(tensor), @@ -1342,8 +1504,8 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } else if (node_def.op() == "Add") { shift_weights = weights; } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); + // This should not happen. + return errors::Unimplemented("Binary op not supported at ", node_def.op()); } nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( @@ -1353,8 +1515,8 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // transpose back dimension - if (permutation_flag) { + // Transpose back dimension + if (need_to_permute) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), permutation, &output_tensor)); @@ -1398,7 +1560,7 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { return tensorflow::errors::Internal( "Conv2D expects kernel of dimension 4, at: " + node_def.name()); } - if (params->converter->is_fp16()) { + if (params->converter->precision_mode() == FP16MODE) { weights_rsck = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); } @@ -1445,6 +1607,8 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) { nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable( + 
const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions()); @@ -1486,9 +1650,9 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, params->node_def.name()); } -tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params, - const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r) { +Status BinaryTensorOpTensor(OpConverterParams* params, + const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r) { const auto& node_def = params->node_def; static const std::unordered_map ops{ {"Add", nvinfer1::ElementWiseOperation::kSUM}, @@ -1499,50 +1663,52 @@ tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params, {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, }; + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) { + return errors::Unimplemented("Binary op ", node_def.op(), + " not supported at: ", node_def.name()); + } - const nvinfer1::ITensor* tensor_l; - const nvinfer1::ITensor* tensor_r; - - nvinfer1::Dims dim_l; - nvinfer1::Dims dim_r; - - if (!TensorRTGetBroadcastShape(operand_l.GetTrtDims(), operand_l.is_tensor(), - operand_r.GetTrtDims(), operand_r.is_tensor(), - &dim_l, &dim_r)) { - return tensorflow::errors::InvalidArgument( - "Binary op broadcast scheme not supported by TensorRT op: " + - node_def.op() + ", at: " + node_def.name()); + nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; + Status status = params->converter->GetTrtBroadcastShape( + operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r); + if (!status.ok()) { + return errors::InvalidArgument( + "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", + status.error_message()); } + if (params->validation_only) return Status::OK(); - TF_RETURN_IF_ERROR( - params->converter->PrepareTensorForShape(operand_l, dim_l, 
&tensor_l)); - TF_RETURN_IF_ERROR( - params->converter->PrepareTensorForShape(operand_r, dim_r, &tensor_r)); + const nvinfer1::ITensor* tensor_l = nullptr; + const nvinfer1::ITensor* tensor_r = nullptr; + status = params->converter->PrepareTensorForShape( + operand_l, broadcasted_dims_l, &tensor_l); + if (status.ok()) { + status = params->converter->PrepareTensorForShape( + operand_r, broadcasted_dims_r, &tensor_r); + } + if (!status.ok()) { + return errors::Internal("Failed to convert binary op ", node_def.name(), + ": ", status.error_message()); + } - // get trt type & shape + // Check type consistency. TFAttrs attrs(node_def); - // maybe this part has to be moved into the block of rsqrt later nvinfer1::DataType dtype = attrs.get("T"); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) + << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) + << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); - // check type consistency - TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype); - TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype); - auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) { - return tensorflow::errors::Unimplemented( - "binary op: ", node_def.op(), " not supported at: ", node_def.name()); - } - + // Add ElementWise layer. nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - // TODO(aaroey): will tensor_l/tensor_r get modified? 
*const_cast(tensor_l), *const_cast(tensor_r), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // pass the output + // Pass the output params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1723,6 +1889,133 @@ tensorflow::Status ConvertReshape(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertExpandDims(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 2) { + return tensorflow::errors::InvalidArgument( + "Two inputs expected for ExpandDims, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "ExpandDims expects tensor for input, at ", node_def.name()); + } + if (!inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "ExpandDims expects weights for axis, at ", node_def.name()); + } + // Get input shape as vector. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Get axis to expand on. + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (weights.count() != 1) { + return tensorflow::errors::InvalidArgument( + "ExpandDims axis must be a scalar, at ", node_def.name()); + } + const int* weights_ptr = + static_cast(const_cast(weights.GetValues())); + int axis = weights_ptr[0]; + // Make sure axis is valid. + if ((axis < (-input_rank - 1)) || (axis > input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. 
+ if (axis < 0) axis += input_rank + 1; + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Modifying batch dimension is not supported for ExpandDims, at ", + node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // ExpandDims: Insert new dim of size 1. + input_dims.insert(input_dims.begin() + axis, 1); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertSqueeze(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + "One input expected for Squeeze, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Squeeze expects tensor for input, at ", node_def.name()); + } + // Get input shape. + TRT_TensorOrWeights input_tensor = inputs.at(0); + const nvinfer1::Dims dims = input_tensor.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dim back. + input_dims.insert(input_dims.begin(), -1); + const int input_rank = input_dims.size(); + // Mark axes to remove by setting them to 0. + TFAttrs attrs(node_def); + auto squeeze_dims = attrs.get>("squeeze_dims"); + if (squeeze_dims.size() == 0) { + return tensorflow::errors::Unimplemented( + "Squeeze is only implemented for explicit dims, at ", node_def.name()); + } + for (int axis : squeeze_dims) { + // Make sure axis is valid. 
+ if ((axis < -input_rank) || (axis >= input_rank)) { + return tensorflow::errors::InvalidArgument( + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at ", + node_def.name()); + } + // Convert negative axis to corresponding positive axis. + if (axis < 0) axis += input_rank; + // Don't squeeze batch dim. + if (axis == 0) { + return tensorflow::errors::Unimplemented( + "Cannot squeeze batch dimension, at ", node_def.name()); + } + // Make sure target dimension is size 1. + if (input_dims[axis] != 1) { + return tensorflow::errors::InvalidArgument( + "Cannot squeeze a dimension which isn't size 1, at ", + node_def.name()); + } + // Mark dim for removal by setting to 0. + input_dims[axis] = 0; + } + if (params->validation_only) return Status::OK(); + + // Remove all dims which are equal to 0. + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), + input_dims.end()); + // Reshape tensor. + nvinfer1::Dims new_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims, + /*ignore_first_dim=*/true)); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + input_tensor, new_dims, &output_tensor)); + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(output_tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertConv2D(OpConverterParams* params) { return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); } @@ -1789,6 +2082,8 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } @@ -1796,6 +2091,11 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { 
nvinfer1::IPoolingLayer* layer = params->converter->network()->addPooling( *const_cast(tensor), type, ksize); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // TODO(tmorris): Average pooling may not be entirely safe to infer + // quantization range through (at least forwards - backwards should be fine). + // Max pooling is okay. + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), layer->getOutput(0)); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1813,110 +2113,290 @@ tensorflow::Status ConvertPool(OpConverterParams* params) { } tensorflow::Status ConvertActivation(OpConverterParams* params) { - const nvinfer1::ITensor* tensor = params->inputs.at(0).tensor(); + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + node_def.op(), " expects one input, at ", node_def.name()); + } + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::Unimplemented( + node_def.op(), " is only implemented for tensors, at ", + node_def.name()); + } + static const std::unordered_map ops{ + {"Relu", nvinfer1::ActivationType::kRELU}, + {"Sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"Tanh", nvinfer1::ActivationType::kTANH}, + }; + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) { + return tensorflow::errors::Unimplemented( + "Activation op: ", node_def.op(), + " not supported at: ", node_def.name()); + } + if (params->validation_only) return tensorflow::Status::OK(); + + // Start conversion. 
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = params->converter->network()->addActivation( - *const_cast(tensor), - nvinfer1::ActivationType::kRELU); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, params->node_def.name()); + *const_cast(tensor), op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + // Set quantization range for output of Sigmoid, Tanh. + if (node_def.op() == "Sigmoid") { + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f); + } else if (node_def.op() == "Tanh") { + params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); + } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertScale(OpConverterParams* params) { +Status ConvertQuantize(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) { + if ((inputs.size() == 0) || + (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) || + (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) || + (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) || + (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) { + return errors::InvalidArgument("Invalid number of inputs for ", + node_def.op(), ", at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { + // TensorRT will automatically quantize weights, so we will ignore ranges + // for weights. + params->outputs->push_back(inputs.at(0)); + return Status::OK(); + } + float min_range = 0.0f; + float max_range = 0.0f; + if (node_def.op() == "FakeQuantWithMinMaxArgs") { + // Get ranges via node attributes. 
+ TFAttrs attrs(node_def); + if (attrs.count("min") == 0 || attrs.count("max") == 0) { + return errors::InvalidArgument("Min or max attribute not found for ", + node_def.op(), " at ", node_def.name()); + } + min_range = attrs.get("min"); + max_range = attrs.get("max"); + } else if (node_def.op() == "FakeQuantWithMinMaxVars" || + node_def.op() == "QuantizeAndDequantizeV2" || + node_def.op() == "QuantizeAndDequantizeV3") { + // Get ranges via inputs. + if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) { + return errors::InvalidArgument("Min and max inputs for ", node_def.op(), + " must be weights not tensors, at ", + node_def.name()); + } + auto get_weights_value = [&inputs](int index) { + auto raw_weights = static_cast( + const_cast(inputs.at(index).weights().GetValues())); + return raw_weights[0]; + }; + min_range = get_weights_value(1); + max_range = get_weights_value(2); + } else { + return errors::InvalidArgument("Unknown quantization op ", node_def.op(), + ", at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // Store ranges for tensor + params->converter->ProvideQuantizationRange( + const_cast(inputs.at(0).tensor()), min_range, + max_range); + // Sometimes, TRT may not quantize a tensor, either because it chooses to + // execute a higher precision kernel or because of op fusion. In these cases, + // accuracy will suffer if the model was trained to expect quantization at + // that tensor. We should consider adding a clip(tensor, min_range, max_range) + // operation here to ensure that any arbitrarily placed quantize node will + // execute as expected. However, this will negatively affect performance. If + // users train their models in a way which models inference as close as + // possible (i.e. not quantizing in place where fusion will occur), then there + // is no problem with the current implementation. 
+ params->outputs->push_back(inputs.at(0)); + return Status::OK(); +} + +// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports +// Relu6 natively. +tensorflow::Status ConvertRelu6(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Invalid number of inputs for Relu6, at ", node_def.name()); + } + if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( - "ConvertScale only supports tensorweight: ", node_def.name()); + "Relu6 is only implemented for tensors, not weights, at ", + node_def.name()); } + if (params->validation_only) return Status::OK(); + // *************************************************************************** + // TensorRT does not implement Relu6 natively. This function converts Relu6 op + // to available TensorRT ops: Relu6(x) = min(Relu(x), 6) + // *************************************************************************** + // Input Tensor const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->is_fp16()) { - weights = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); - } - TRT_ShapedWeights empty_weights(weights.type_); - TFAttrs attrs(node_def); + // Relu operation i.e. Relu(x) = max(0, x) + nvinfer1::IActivationLayer* relu_layer = + params->converter->network()->addActivation( + *const_cast(tensor), + nvinfer1::ActivationType::kRELU); + TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name()); + + // Large range of relu is problematic during quantization in INT8 precision + // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization. + // TRT only uses dynamic ranges in INT8 precision mode, + // and this does not affect the FP32 path. 
+ params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f, + 6.0f); + + // Create a constant layer to store the floating point weight i.e. 6.0f This + // tensor will be broadcasted uniformly during elementwise `min` operation. + // The constant has to have the same rank as the input in order for TRT to + // broadcast + nvinfer1::Dims dims; + dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims; + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = 1; + } + TRT_ShapedWeights weights = params->weight_store->GetTempWeights( + tensorflow::DataType::DT_FLOAT, dims); + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = 6.0f; + nvinfer1::IConstantLayer* const6_layer = + params->converter->network()->addConstant(dims, weights.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name()); + params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f, + 6.0f); + + // ElementWise Min Operation + // Min op is a nop for INT8 execution path, as the input tensor + // to this layer will only have values in range [0.f, 6.0f]. + const nvinfer1::ITensor* tensor_l = relu_layer->getOutput(0); + const nvinfer1::ITensor* tensor_r = const6_layer->getOutput(0); + nvinfer1::IElementWiseLayer* relu6_layer = + params->converter->network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), + nvinfer1::ElementWiseOperation::kMIN); + TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); + nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); - const auto data_format = attrs.get("data_format"); - int channel_index; - const auto dims = tensor->getDimensions(); - if (data_format == "NHWC") { - // 1). NHWC is really N+C - channel_index = dims.nbDims - 1; // batch dimension is implicit here! - } else { - // 2). NCHW is really N+CHW - channel_index = 0; // batch dimension is implicit here! 
- } + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} - nvinfer1::Permutation permutation; - for (int32_t i = 0; i < dims.nbDims; ++i) { - permutation.order[i] = i; +tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) { + return errors::InvalidArgument("Input expects tensor and weights, at ", + node_def.name()); } + if (params->validation_only) return Status::OK(); - if (channel_index >= 0) { + nvinfer1::ITensor* tensor = + const_cast(inputs.at(0).tensor()); + const nvinfer1::Dims original_dims = tensor->getDimensions(); + TFAttrs attrs(node_def); + const string data_format = attrs.get("data_format"); + const int channel_index = + (data_format == "NHWC" ? original_dims.nbDims - 1 : 0); + + nvinfer1::Permutation permutation; + if (channel_index != 0) { + // Permute the dimensions so that the channel dimension is the first + // dimension. + for (int i = 0; i < original_dims.nbDims; ++i) { + permutation.order[i] = i; + } permutation.order[0] = channel_index; permutation.order[channel_index] = 0; - } else { - return tensorflow::errors::Unimplemented( - "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name()); + VLOG(1) << "ConvertBiasAdd permutation: " + << DebugString(permutation, original_dims.nbDims); } // TensorRT addScale requires input to be of rank 3, we need to apply - // transpose as well as reshape - if (channel_index != 0 || dims.nbDims != 3) { + // transpose as well as reshape. + // TODO(laigd): this doesn't match what the TRT doc says, fix the doc? 
+ if (channel_index != 0 || original_dims.nbDims != 3) { nvinfer1::IShuffleLayer* shuffle_layer = - params->converter->network()->addShuffle( - *const_cast(tensor)); + params->converter->network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + params->converter->MarkQuantizationRangesAsInferrable( + tensor, shuffle_layer->getOutput(0)); + + // NOTE(laigd): for some reason we need to apply the reshape + // unconditionally. The default shape has nbDims==-1 and it seems the + // behavior is undefined in some cases. nvinfer1::Dims reshape_dims; reshape_dims.nbDims = 3; - reshape_dims.d[0] = 0; // 0 copy from the input - reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input - reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest + // 0 means copying from input; -1 means inferring from the rest. + reshape_dims.d[0] = 0; + reshape_dims.d[1] = original_dims.nbDims >= 2 ? 0 : 1; + reshape_dims.d[2] = original_dims.nbDims >= 3 ? -1 : 1; + shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { - // maybe we do not need this check. 
concerned about TRT optimization shuffle_layer->setFirstTranspose(permutation); } - shuffle_layer->setReshapeDimensions(reshape_dims); tensor = shuffle_layer->getOutput(0); } + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (params->converter->precision_mode() == FP16MODE) { + weights = ConvertFP32ToFP16(params->weight_store, weights); + } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; if (weights.shape_.d[0] == 1) { mode = nvinfer1::ScaleMode::kUNIFORM; } + TRT_ShapedWeights empty_weights(weights.type_); nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *const_cast(tensor), mode, weights.GetTrtWeights(), - empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); + *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), + empty_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // restore transpose & reshape - if (channel_index != 0 || dims.nbDims != 3) { + // Restore transpose & reshape. + if (channel_index != 0 || original_dims.nbDims != 3) { nvinfer1::IShuffleLayer* shuffle_layer = - params->converter->network()->addShuffle( - *const_cast(output_tensor)); + params->converter->network()->addShuffle(*output_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); - nvinfer1::Dims reshape_dims = dims; - int tmp = reshape_dims.d[channel_index]; - reshape_dims.d[channel_index] = reshape_dims.d[0]; - reshape_dims.d[0] = tmp; + // NOTE: for same reason as mentioned above we need to apply the reshape + // unconditionally. + nvinfer1::Dims reshape_dims = original_dims; + if (channel_index != 0) { + // NOTE: according to NVIDIA dimension types are deprecated, so we don't + // need to copy them back. 
+ reshape_dims.d[channel_index] = original_dims.d[0]; + reshape_dims.d[0] = original_dims.d[channel_index]; + } shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { shuffle_layer->setSecondTranspose(permutation); } + params->converter->MarkQuantizationRangesAsInferrable( + output_tensor, shuffle_layer->getOutput(0)); output_tensor = shuffle_layer->getOutput(0); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } Status GetTensorDimsWithProtoShape(const Tensor& tensor, @@ -2070,32 +2550,41 @@ tensorflow::Status ConvertConst(OpConverterParams* params) { } tensorflow::Status ConvertIdentity(OpConverterParams* params) { + // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT + // 5.0, however once we know that it does it would be nice to use that + // instead. params->outputs->push_back(params->inputs.at(0)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertBinary(OpConverterParams* params) { +Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { - return tensorflow::errors::FailedPrecondition( - "Binary ops require two tensor input, at ", node_def.name()); + return errors::InvalidArgument("Binary ops require two inputs, at ", + node_def.name()); } // Constant folding should have been done by TensorFlow - if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Constant folding is falled back to TensorFlow, binary op received " "both input as constant at: ", node_def.name()); } - // Try to convert into Scale layer first (for better performance) + // TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with + // IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. 
For + // now, the performance will be slightly better with IScaleLayer because it + // can be fused in more situations. However, most of the benefits of + // IScaleLayer are when the layer performs both a shift and a scale, which we + // don't do except for convolutions. + // + // Try to convert into Scale layer first (for better performance). // Since scale layer supports restricted broadcast policy and op types, we // allow failure and try to handle it through Elementwise op - // (BinaryTensorOpTensor) - Status status = tensorflow::Status::OK(); + // (BinaryTensorOpTensor). + Status status = Status::OK(); if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { status = BinaryTensorOpWeight(params, inputs.at(0).tensor(), inputs.at(1).weights(), false); @@ -2103,7 +2592,10 @@ tensorflow::Status ConvertBinary(OpConverterParams* params) { status = BinaryTensorOpWeight(params, inputs.at(1).tensor(), inputs.at(0).weights(), true); } + // If both input are tensors, or one of them is weights but the conversion + // above failed, try the conversion using BinaryTensorOpTensor. if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { + if (!status.ok()) VLOG(1) << status; status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1)); } return status; @@ -2133,6 +2625,20 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer; if (node_def.op() == "Rsqrt") { + // We will need a quantization range for intermediate tensor if not using + // calibration. 
+ // + // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) + // ^ + // need range here + if (params->converter->precision_mode() == INT8MODE && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration for Rsqrt, consider replacing with " + "Sqrt -> FakeQuant -> Reciprocal ops, at ", + node_def.name()); + } layer = params->converter->network()->addUnary( *const_cast(tensor), nvinfer1::UnaryOperation::kSQRT); @@ -2156,6 +2662,48 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertSquare(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 1) { + return tensorflow::errors::InvalidArgument("Square expects one input, at ", + node_def.name()); + } + if (inputs.at(0).is_weights()) { + return tensorflow::errors::Unimplemented( + "Square is only implemented for tensors, at ", node_def.name()); + } + if (params->validation_only) return Status::OK(); + + // Constant 2 with same rank as input + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + for (int i = 0; i < dims.nbDims; i++) { + dims.d[i] = 1; + } + TRT_ShapedWeights weights = params->weight_store->GetTempWeights( + tensorflow::DataType::DT_FLOAT, dims); + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = 2.f; + nvinfer1::IConstantLayer* const2_layer = + params->converter->network()->addConstant(dims, weights.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(const2_layer, node_def.name()); + + // ElementWise Pow Operation + const nvinfer1::ITensor* tensor_l = inputs.at(0).tensor(); + const nvinfer1::ITensor* tensor_r = const2_layer->getOutput(0); + nvinfer1::IElementWiseLayer* layer = + params->converter->network()->addElementWise( + *const_cast(tensor_l), + *const_cast(tensor_r), + nvinfer1::ElementWiseOperation::kPOW); + 
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertReduce(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -2692,6 +3240,8 @@ tensorflow::Status ConvertSoftmax(OpConverterParams* params) { layer->setAxes(1 << (nbDims - 1)); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + // Quantization range for SoftMax is always (0, 1) + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -2732,40 +3282,54 @@ tensorflow::Status ConvertTopK(OpConverterParams* params) { return tensorflow::Status::OK(); } -void TrtNodeValidator::RegisterOpValidators() { +static void RegisterValidatableOpConverters( + std::unordered_map* registration) { // TODO(laigd): support all op types. 
- op_validators_["Const"] = ConvertConst; - op_validators_["Transpose"] = ConvertTranspose; - op_validators_["Reshape"] = ConvertReshape; - op_validators_["MatMul"] = ConvertMatMul; + (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["Const"] = ConvertConst; + (*registration)["Transpose"] = ConvertTranspose; + (*registration)["Reshape"] = ConvertReshape; + (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Square"] = ConvertSquare; + (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["Squeeze"] = ConvertSqueeze; + + for (auto quantization_op_type : + {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs"}) { + (*registration)[quantization_op_type] = ConvertQuantize; + } + for (auto binary_op_type : + {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum"}) { + (*registration)[binary_op_type] = ConvertBinary; + } + for (auto activation_op_type : {"Relu", "Sigmoid", "Tanh"}) { + (*registration)[activation_op_type] = ConvertActivation; + } +} + +void TrtNodeValidator::RegisterOpValidators() { + RegisterValidatableOpConverters(&op_validators_); } void Converter::RegisterOpConverters() { - // vgg_16 slim implementation + RegisterValidatableOpConverters(&op_registry_); + op_registry_["Conv2D"] = ConvertConv2D; op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; op_registry_["AvgPool"] = ConvertPool; - op_registry_["BiasAdd"] = ConvertScale; - op_registry_["Const"] = ConvertConst; // TODO(ben,jie): this is a temp hack. 
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - // resnet_50_v1 slim implementation - op_registry_["Add"] = ConvertBinary; - op_registry_["Mul"] = ConvertBinary; - op_registry_["Sub"] = ConvertBinary; op_registry_["Pad"] = ConvertPad; op_registry_["ConcatV2"] = ConvertConcat; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Div"] = ConvertBinary; - op_registry_["RealDiv"] = ConvertBinary; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; @@ -2774,18 +3338,12 @@ void Converter::RegisterOpConverters() { op_registry_["Abs"] = ConvertUnary; op_registry_["Neg"] = ConvertUnary; - op_registry_["Transpose"] = ConvertTranspose; - op_registry_["Reshape"] = ConvertReshape; - op_registry_["Sum"] = ConvertReduce; op_registry_["Prod"] = ConvertReduce; op_registry_["Max"] = ConvertReduce; op_registry_["Min"] = ConvertReduce; op_registry_["Mean"] = ConvertReduce; - op_registry_["Maximum"] = ConvertBinary; - op_registry_["Minimum"] = ConvertBinary; op_registry_["Softmax"] = ConvertSoftmax; - op_registry_["MatMul"] = ConvertMatMul; op_registry_["BatchMatMul"] = ConvertBatchMatMul; op_registry_["TopKV2"] = ConvertTopK; @@ -2798,7 +3356,7 @@ tensorflow::Status ConvertGraphDefToEngine( const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, - TrtUniquePtrType* engine, + TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully) { engine->reset(); if (convert_successfully) *convert_successfully = false; @@ -2813,7 +3371,11 @@ tensorflow::Status ConvertGraphDefToEngine( builder->setHalf2Mode(true); } else if (precision_mode == INT8MODE) { builder->setInt8Mode(true); - builder->setInt8Calibrator(calibrator); + if (use_calibration) { + 
builder->setInt8Calibrator(calibrator); + } else { + builder->setInt8Calibrator(nullptr); + } } // Create the network. @@ -2826,7 +3388,7 @@ tensorflow::Status ConvertGraphDefToEngine( // Build the network VLOG(1) << "Starting engine conversion "; - Converter converter(trt_network.get(), precision_mode == FP16MODE); + Converter converter(trt_network.get(), precision_mode, use_calibration); std::vector> output_tensors; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { @@ -2882,6 +3444,9 @@ tensorflow::Status ConvertGraphDefToEngine( TF_RETURN_IF_ERROR(converter.RenameAndMarkOutputTensors(output_tensors)); if (convert_successfully) *convert_successfully = true; + // Apply user provided quantization ranges to tensors + converter.MaybeApplyQuantizationRanges(); + // Build the engine. VLOG(1) << "Starting engine creation"; engine->reset(builder->buildCudaEngine(*converter.network())); @@ -3026,7 +3591,8 @@ tensorflow::Status ConvertSegmentToGraphDef( } } *common_scope = local_scope; - VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; + VLOG(1) << "Converted TensorRT candidate segment @scope '" << local_scope + << "' to a GraphDef"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 5cc28b33e7f2c56d2f281d24e8390d253a8228f5..54e19b73957bccdae2b23bd3556de9ad00b864e5 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -92,7 +92,8 @@ struct EngineInfo { EngineInfo() : engine_type(EngineType::TRTStatic), max_workspace_size_bytes(0), - precision_mode(FP32MODE) {} + precision_mode(FP32MODE), + use_calibration(true) {} string engine_name; string device; @@ -109,6 +110,7 @@ struct EngineInfo { int maximum_cached_engines; std::vector cached_engine_batches; int precision_mode; + bool use_calibration; }; // Constructs a 
graphdef from the segment in the given graph. Adds placeholder @@ -145,7 +147,7 @@ tensorflow::Status ConvertGraphDefToEngine( const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, - TrtUniquePtrType* engine, + TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully); // Helper class for the segmenter to determine whether an output edge from the @@ -392,7 +394,8 @@ class TrtNodeValidator { // Class to convert TF nodes to TRT network. class Converter { public: - Converter(nvinfer1::INetworkDefinition* trt_network, bool is_fp16); + Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode, + bool use_calibration); ////////////////////////////////////////////////////////////////////////////// // Methods used by the TRT engine builder to build a TRT network from a TF @@ -422,8 +425,27 @@ class Converter { // to add TRT layers. nvinfer1::INetworkDefinition* network() { return trt_network_; } - // Is the converter operating in fp16 mode? - bool is_fp16() const { return is_fp16_; } + // What precision are we targeting? + int precision_mode() const { return precision_mode_; } + + // Calibration will be or was previously performed on this network? + bool use_calibration() const { return use_calibration_; } + + // This should be called on the inputs and outputs of any layer we create + // where we know that the quantization range does not change during that + // operation. (e.g. Reshape, Transpose, Identity, MaxPool). + void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input, + nvinfer1::ITensor* output); + + // This function should be called when we know the quantization range of a + // tensor, either from a quantize/dequantize node or when the output is a + // fixed range (e.g. SoftMax, Relu6, Sigmoid). 
+ void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range, + float max_range); + + // Should be called when full TRT network has been constructed and before + // building the engine. + void MaybeApplyQuantizationRanges(); // Below are helper methods for op converters to add different layers to the // TRT network. @@ -440,6 +462,13 @@ class Converter { const nvinfer1::Dims& dims, const nvinfer1::ITensor** tensor); + // Return OK if the broadcast scheme is supported and compute the shapes after + // broadcasting. + Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) const; + private: // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. @@ -457,6 +486,12 @@ class Converter { void RegisterOpConverters(); + void PropagateQuantizationRanges(); + + // Gets the min and max value in a TRT_ShapedWeights + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const; + // Registered op converters by op type. std::unordered_map op_registry_; @@ -472,7 +507,25 @@ class Converter { // Store the weights added during construction of trt_network_. TrtWeightStore weight_store_; - const bool is_fp16_; + // During conversion, this table is populated with quantization ranges per + // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT + // quantization ranges. Since TRT only supports symmetric ranges, we will + // store the range as a single float = max(abs(min_range), abs(max_range)). + // Range refers to the floating point values, e.g. min_range = 0.0f, max_range + // = 6.0f for Relu6. + std::unordered_map quantization_ranges_; + + // Edges where quantization ranges can be inferred (copied) across ops - from + // first tensor to second tensor. 
PropagateQuantizationRanges() will propagate + // known ranges from quantization_ranges_ across these edges, adding the new + // ranges to quantization_ranges_ so that they can be applied in + // MaybeApplyQuantizationRanges(). + std::vector> + quantization_infer_; + + const int precision_mode_; + + const bool use_calibration_; // Batch size of inputs to trt_network_ added by AddInputTensor(). During // network construction it will update this, use it to verify the batch diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index c3a39395f3a99f3e471e09688a11cc0ebba61ff4..c37a43dd5def9daf3c5d70720c6db2aab20db077 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -35,7 +35,10 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT +#include "tensorflow/core/public/session.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -47,7 +50,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrCat; using ::testing::ElementsAre; +using ::testing::ElementsAreArray; // TODO(laigd): put this into some test utils file. 
void ExpectStatus(Status status, error::Code code = error::OK, @@ -69,6 +74,32 @@ nvinfer1::Dims GetTestDims(const std::vector& d) { return dims; } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { + switch (tf_dtype) { + case DT_FLOAT: + return nvinfer1::DataType::kFLOAT; + case DT_HALF: + return nvinfer1::DataType::kHALF; + case DT_INT32: + return nvinfer1::DataType::kINT32; + default: + QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); + } +} + +DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return DT_FLOAT; + case nvinfer1::DataType::kHALF: + return DT_HALF; + case nvinfer1::DataType::kINT32: + return DT_INT32; + default: + QCHECK(false) << "Unexpected data type " << static_cast(trt_dtype); + } +} + NodeDef MakeNodeDef(const string& name, const string& op, const std::vector& inputs) { NodeDef node_def; @@ -111,6 +142,35 @@ bool TrtDimsEqualsArray(const std::vector& lhs, return TrtDimsEquals(GetTestDims(lhs), rhs); } +// TODO(laigd): define a parameterized matcher that can compare against the +// vector. +void ExpectTrtDimsEqualsArray(const std::vector& lhs, + const nvinfer1::Dims& rhs) { + EXPECT_TRUE(TrtDimsEqualsArray(lhs, rhs)) + << "expected: " << DebugString(GetTestDims(lhs)) << "\n" + << " actual: " << DebugString(rhs); +} + +template +void ExpectArrayNear(const std::vector& lhs, const std::vector& rhs) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (int i = 0; i < lhs.size(); i++) { + EXPECT_FLOAT_EQ(lhs[i], rhs[i]); + } +} + +// Eigen::half cannot implicitly convert to float which is required for +// EXPECT_FLOAT_EQ. 
+template <> +void ExpectArrayNear(const std::vector& lhs, + const std::vector& rhs) { + ASSERT_EQ(lhs.size(), rhs.size()); + for (int i = 0; i < lhs.size(); i++) { + EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]), + Eigen::half_impl::half_to_float(rhs[i])); + } +} + bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, const TRT_ShapedWeights& rhs) { return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ && @@ -121,8 +181,7 @@ template void ValidateWeights(const TRT_ShapedWeights& weights, const std::vector& expected_dims, const std::vector& expected_value) { - EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_)) - << weights.DebugString(); + ExpectTrtDimsEqualsArray(expected_dims, weights.shape_); ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString(); const T* actual_values = static_cast(weights.GetValues()); for (int i = 0; i < expected_value.size(); ++i) { @@ -133,11 +192,12 @@ void ValidateWeights(const TRT_ShapedWeights& weights, // Fake ITensor implementation for testing purposes. 
class FakeITensor : public nvinfer1::ITensor { public: - FakeITensor() {} + FakeITensor() : dynamic_range_(0.0f) {} - FakeITensor(const nvinfer1::Dims& dims) : dims_(dims) {} + FakeITensor(const nvinfer1::Dims& dims) : dims_(dims), dynamic_range_(0.0f) {} - FakeITensor(const std::vector& dims) : dims_(GetTestDims(dims)) {} + FakeITensor(const std::vector& dims) + : dims_(GetTestDims(dims)), dynamic_range_(0.0f) {} void setName(const char* name) override { name_ = name; } @@ -166,7 +226,12 @@ class FakeITensor : public nvinfer1::ITensor { } #if NV_TENSORRT_MAJOR >= 5 - bool setDynamicRange(float min, float max) override {} + bool setDynamicRange(float min, float max) override { + dynamic_range_ = std::max(std::abs(min), std::abs(max)); + return true; + } + + float getDynamicRange() const override { return dynamic_range_; } #endif private: @@ -174,6 +239,7 @@ class FakeITensor : public nvinfer1::ITensor { nvinfer1::Dims dims_; nvinfer1::DataType type_; nvinfer1::TensorLocation location_; + float dynamic_range_; }; TEST(TRT_ShapedWeights_Test, Basic) { @@ -265,9 +331,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(1, ptr->batch_size()); } EXPECT_EQ(&itensor, ptr->tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims()); } } } @@ -286,9 +350,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(false, ptr->is_weights()); EXPECT_EQ(1, ptr->batch_size()); EXPECT_NE(nullptr, ptr->tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims()); } } // Test constructor with TRT_ShapedWeights argument. 
@@ -305,9 +367,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { nvinfer1::Dims dims; dims.nbDims = 0; - EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims()); } } } @@ -341,34 +401,50 @@ TEST_F(ValidatorTest, ConvertToTensorOrWeights) { graph_properties, &output)); ValidateWeights(output.weights(), {2}, {1.0, 2.0}); } - // Convert non-Const. We test the case where the non-batch dimemsion is - // unknown as well, to make sure the validator allows that. - for (const int32 non_batch_dim : {-1, 2}) { - const int32 batch_size = 12; + // Helper method to run ConvertToTensorOrWeights() with predefined parameters. + auto convert_to_tensor_or_weights = [this](const std::vector& dims, + TRT_TensorOrWeights* output) { Scope s = Scope::NewRootScope(); - ops::Placeholder::Attrs attrs; - TF_EXPECT_OK(TensorShapeUtils::MakeShape( - std::vector{batch_size, non_batch_dim}, &attrs.shape_)); + const auto attrs = ops::Placeholder::Shape(PartialTensorShape{dims}); auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs); auto add = ops::Add(s.WithOpName("add"), feed, feed); grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); - grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - - auto& node_def = add.operation.node()->def(); + const NodeDef& node_def = add.operation.node()->def(); + return this->ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, output); + }; + // Convert non-Const with #dims > nvinfer1::Dims::MAX_DIMS+1. 
+ { TRT_TensorOrWeights output; - ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, - graph_properties, &output)); + ExpectStatus( + convert_to_tensor_or_weights( + std::vector(nvinfer1::Dims::MAX_DIMS + 2, 1), &output), + error::OUT_OF_RANGE, "Input tensor rank is greater than 9"); + } + // Convert non-Const with #dims < 2. + { + TRT_TensorOrWeights output; + ExpectStatus( + convert_to_tensor_or_weights({1}, &output), error::INVALID_ARGUMENT, + "Input tensor with rank<2 is not supported since the first dimension " + "is treated as batch dimension by TRT"); + } + // Convert non-Const. We test the case where the non-batch dimemsion is + // unknown as well, to make sure the validator allows that. + for (const int32 non_batch_dim : {-1, 2}) { + const int32 batch_size = 12; + TRT_TensorOrWeights output; + ExpectStatus( + convert_to_tensor_or_weights({batch_size, non_batch_dim}, &output)); EXPECT_EQ(true, output.is_tensor()); EXPECT_EQ(batch_size, output.batch_size()); EXPECT_NE(nullptr, output.tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims())) - << "- expected: {" << non_batch_dim << "} \n vs\n" - << "- actual: " << DebugString(output.GetTrtDims()); + ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims()); } } @@ -405,7 +481,9 @@ class ConverterTest : public ::testing::Test { ConverterTest() { builder_.reset(nvinfer1::createInferBuilder(logger_)); network_.reset(builder_->createNetwork()); - converter_.reset(new Converter(network_.get(), /*fp16=*/false)); + converter_.reset(new Converter(network_.get(), + /*precision_mode=*/FP32MODE, + /*use_calibration=*/false)); weight_store_ = &converter_->weight_store_; } @@ -432,8 +510,21 @@ class ConverterTest : public ::testing::Test { return converter_->GetInputs(node_def, inputs); } + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const { + return converter_->GetWeightRange(weights, out_min, out_max); + } + + void 
PropagateQuantizationRanges() { + converter_->PropagateQuantizationRanges(); + } + int batch_size() const { return converter_->batch_size_; } + std::unordered_map& quantization_ranges() { + return converter_->quantization_ranges_; + } + private: Logger logger_; // These members are ordered in a way such that the destruction order is: @@ -504,9 +595,9 @@ TEST_F(ConverterTest, AddAndGetInputs) { EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType()); EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType()); EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions())); - EXPECT_TRUE(TrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions())); - EXPECT_TRUE(TrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions())); + ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions()); } TEST_F(ConverterTest, RenameAndMarkOutputTensors) { @@ -552,7 +643,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { {{"my_op", "my_output"}, {"my_op:1", "my_output_1"}})); EXPECT_EQ(2, output_tensors.size()); for (auto output_tensor : output_tensors) { - EXPECT_TRUE(TrtDimsEqualsArray({2, 1}, output_tensor->getDimensions())); + ExpectTrtDimsEqualsArray({2, 1}, output_tensor->getDimensions()); } EXPECT_EQ("my_output", string(output_tensors[0]->getName())); EXPECT_EQ("my_output_1", string(output_tensors[1]->getName())); @@ -577,8 +668,7 @@ TEST_F(ConverterTest, TransposeTensor) { // OK. 
TF_EXPECT_OK( converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { @@ -590,7 +680,7 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { // Shape size doesn't match. ExpectStatus(converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}), &output_tensor), - error::INVALID_ARGUMENT, "Reshape shapes are not compatible."); + error::INVALID_ARGUMENT, "Reshape shapes are not compatible"); // TODO(aaroey): we should check the case where uninferred dimensions are not // an exact divisor of input dim ensions, e.g. for dims {-1, 7}. @@ -598,14 +688,12 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { // Infer shape, ok. TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({-1, 2}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({15, 2}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()); // Regular shape. 
TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, PrepareTensorForShape_Weights) { @@ -615,8 +703,7 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { const nvinfer1::ITensor* output_tensor = nullptr; TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -656,6 +743,178 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) { "tensor/weights my_tensor already exist"); } +template +void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { + TRT_ShapedWeights weights = + weight_store->GetTempWeights(DataTypeToEnum::v(), GetTestDims({2, 3})); + const std::vector values = {T(3), T(1), T(2), T(6), T(5), T(4)}; + memcpy(const_cast(weights.GetValues()), values.data(), + weights.size_bytes()); + + float out_min = 0.0f; + float out_max = 0.0f; + TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max)); + EXPECT_EQ(1.0f, out_min); + EXPECT_EQ(6.0f, out_max); +} + +TEST_F(ConverterTest, GetWeightRange) { + TestGetWeightRange(this, weight_store_); + TestGetWeightRange(this, weight_store_); + TestGetWeightRange(this, weight_store_); +} + +TEST_F(ConverterTest, ProvideQuantizationRange) { + FakeITensor fake_tensor; + // Assymetric range + converter_->ProvideQuantizationRange(&fake_tensor, 0.0f, 6.0f); + EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]); + converter_->ProvideQuantizationRange(&fake_tensor, 1.0f, 6.0f); + EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]); + converter_->ProvideQuantizationRange(&fake_tensor, -8.0f, 6.0f); + 
EXPECT_EQ(8.0f, quantization_ranges()[&fake_tensor]); + converter_->ProvideQuantizationRange(&fake_tensor, -8.123f, -6.123f); + EXPECT_EQ(8.123f, quantization_ranges()[&fake_tensor]); + // Symmetric range + converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f); + EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]); +} + +TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { + // input -> infer1 -> infer2 -> infer3 + FakeITensor input, infer_1, infer_2, infer_3; + FakeITensor not_infer; + Converter int8_converter(/*trt_network=*/nullptr, INT8MODE, + /*use_calibration=*/true); + int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f); + int8_converter.ProvideQuantizationRange(&not_infer, -100.0f, 100.0f); + int8_converter.MarkQuantizationRangesAsInferrable(&input, &infer_1); + int8_converter.MarkQuantizationRangesAsInferrable(&infer_1, &infer_2); + int8_converter.MarkQuantizationRangesAsInferrable(&infer_2, &infer_3); + + // Input range should be inferred along the chain and applied to tensors.
+ int8_converter.MaybeApplyQuantizationRanges(); +#if NV_TENSORRT_MAJOR >= 5 + EXPECT_EQ(input.getDynamicRange(), 5.0f); + EXPECT_EQ(infer_1.getDynamicRange(), 5.0f); + EXPECT_EQ(infer_2.getDynamicRange(), 5.0f); + EXPECT_EQ(infer_3.getDynamicRange(), 5.0f); + EXPECT_EQ(not_infer.getDynamicRange(), 100.0f); +#endif +} + +TEST_F(ConverterTest, PropagateQuantizationRanges) { + // infer0 <-> infer1 <-> infer2 <-> infer3 + // | + // infer4 <-> infer5 + FakeITensor infer[6]; + FakeITensor not_infer; + converter_->ProvideQuantizationRange(&infer[4], -5.0f, 5.0f); + converter_->MarkQuantizationRangesAsInferrable(&infer[0], &infer[1]); + converter_->MarkQuantizationRangesAsInferrable(&infer[1], &infer[2]); + converter_->MarkQuantizationRangesAsInferrable(&infer[3], &infer[2]); + converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[1]); + converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[5]); + + // Input range should be inferred along the chain. + PropagateQuantizationRanges(); + auto ranges = quantization_ranges(); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(5.0f, ranges[&infer[i]]); + } + EXPECT_EQ(ranges.count(&not_infer), 0); +} + +TEST_F(ConverterTest, GetTrtBroadcastShape) { + const bool kIsTensor = true; + const bool kIsNotTensor = false; + auto symmetric_test = [this](const std::vector& operand_1_shape, + const std::vector& operand_2_shape, + const bool operand_1_is_tensor, + const bool operand_2_is_tensor, + const std::vector& expected_operand_1_shape, + const std::vector& expected_operand_2_shape, + error::Code expected_code = error::OK, + const char* expected_error_msg_substr = nullptr, + const int operand_1_batch_size = -1, + const int operand_2_batch_size = -1) { + auto create_tensor_or_weights = [](const std::vector& shape, + bool is_tensor, int batch_size = -1) { + if (is_tensor) { + return TRT_TensorOrWeights{nvinfer1::DataType::kFLOAT, + GetTestDims(shape), batch_size}; + } + TRT_ShapedWeights weights; + weights.shape_ =
GetTestDims(shape); + return TRT_TensorOrWeights(weights); + }; + + nvinfer1::Dims operand_1_new_dims, operand_2_new_dims; + TRT_TensorOrWeights operand_1 = create_tensor_or_weights( + operand_1_shape, operand_1_is_tensor, operand_1_batch_size); + TRT_TensorOrWeights operand_2 = create_tensor_or_weights( + operand_2_shape, operand_2_is_tensor, operand_2_batch_size); + + // operand_1 broadcast operand_2 + ExpectStatus( + this->converter_->GetTrtBroadcastShape( + operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims), + expected_code, expected_error_msg_substr); + if (expected_code == error::OK) { + ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); + ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); + } + // operand_2 broadcast operand_1 + ExpectStatus( + this->converter_->GetTrtBroadcastShape( + operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims), + expected_code, expected_error_msg_substr); + if (expected_code == error::OK) { + ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); + ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); + } + }; + + // Both inputs are weights. + symmetric_test( + {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT, + "Broadcasting requires at least one of the operands be tensors"); + + // One tensor and one weights. 
+ symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2}); + symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2}); + symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1}); + symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {1, 2, 3}); + symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {2, 3, 4}); + symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {2, 3, 4}); + symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4}, + {2, 1, 4}); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme", + /*operand_1_batch_size=*/2); + symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); + + // Both inputs are tensors. + symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 3 vs broadcast #dims 4)"); + symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, + {2, 1, 4}); + symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -684,15 +943,21 @@ class OpConverterTest : public ::testing::Test { // Reset the validator and converter. 
validator_.reset(new TrtNodeValidator); - converter_.reset(new Converter(network_.get(), /*fp16=*/false)); + converter_.reset(new Converter(network_.get(), + /*precision_mode=*/FP32MODE, + /*use_calibration=*/false)); // Reset other related artifacts. scope_ = Scope::NewRootScope(); validator_inputs_.clear(); } - void BuildAndRun(const char* input_name, const std::vector& input_data, - const char* output_name, std::vector* output_data) { + // TODO(laigd): test fp16 and int8 support. + template + void BuildAndRun( + const std::vector>>& + input_data, + const char* output_name, std::vector* output_data) { // Mark the output tensor as TRT engine output. TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors( {{string(output_name), string(output_name)}})); @@ -703,25 +968,33 @@ class OpConverterTest : public ::testing::Test { CHECK_NOTNULL(engine_.get()); // Execute the TRT engine. - const int input_size = input_data.size() * sizeof(float); - const int output_size = output_data->size() * sizeof(float); - const int input_index = engine_->getBindingIndex(input_name); - const int output_index = engine_->getBindingIndex(output_name); + ASSERT_LE(input_data.size() + 1, 3); + void* buffers[3]; + for (const auto name_and_data : input_data) { + const int input_size = name_and_data.second.size() * sizeof(T); + const int input_index = engine_->getBindingIndex(name_and_data.first); + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); + ASSERT_EQ( + 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(), + input_size, cudaMemcpyHostToDevice, stream_)); + } - ASSERT_EQ(engine_->getNbBindings(), 2); - void* buffers[2]; - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); + const int output_size = output_data->size() * sizeof(T); + const int output_index = engine_->getBindingIndex(output_name); ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size)); - ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input_data.data(), - input_size, 
cudaMemcpyHostToDevice, stream_)); + + ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1); + TrtUniquePtrType execution_context( engine_->createExecutionContext()); execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr); ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index], output_size, cudaMemcpyDeviceToHost, stream_)); cudaStreamSynchronize(stream_); - ASSERT_EQ(0, cudaFree(buffers[input_index])); - ASSERT_EQ(0, cudaFree(buffers[output_index])); + + for (int i = 0; i < input_data.size() + 1; ++i) { + ASSERT_EQ(0, cudaFree(buffers[i])); + } } bool HasStaticShape(const nvinfer1::Dims& dims) const { @@ -736,18 +1009,7 @@ class OpConverterTest : public ::testing::Test { void AddTestTensor( const char* name, const std::vector& dims, int batch_size = 1, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { - DataType tf_dtype = DT_FLOAT; - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - tf_dtype = DT_FLOAT; - break; - case nvinfer1::DataType::kINT32: - tf_dtype = DT_INT32; - break; - default: - ASSERT_TRUE(false) << "Unexpected data type " - << static_cast(trt_dtype); - } + DataType tf_dtype = TrtDataTypeToTf(trt_dtype); ops::Placeholder::Attrs attrs; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); attrs.shape_.InsertDim(0, batch_size); @@ -826,6 +1088,11 @@ class OpConverterTest : public ::testing::Test { } } + // Expose quantization_ranges_ for tests + std::unordered_map& quantization_ranges() { + return converter_->quantization_ranges_; + } + std::unique_ptr converter_; std::unique_ptr validator_; @@ -835,6 +1102,11 @@ class OpConverterTest : public ::testing::Test { TrtUniquePtrType network_; TrtUniquePtrType engine_; cudaStream_t stream_; + // Used to create placeholders with shape and data type information. The + // created placeholders will be used as inputs to the node to be verified, + // thus we need the shape and data type information to get a non-empty + // GraphProperties. 
+ // TODO(laigd): consider use this Scope to create the NodeDef to verify. Scope scope_; std::unordered_map validator_inputs_; }; @@ -958,15 +1230,15 @@ TEST_F(OpConverterTest, ConvertTranspose) { Reset(); AddTestTensor("input", {1, 2, 3}); AddTestWeights("weights", {4}, {0, 3, 1, 2}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions())) - << output.DebugString(); + ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); std::vector output_data(6); - BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_transpose", &output_data); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose", + &output_data); EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6)); } } @@ -1048,15 +1320,15 @@ TEST_F(OpConverterTest, ConvertReshape) { Reset(); AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); AddTestWeights("weights", {4}, ok_params[i].shape); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions())) - << output.DebugString(); + ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); std::vector output_data(6); - BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_reshape", &output_data); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape", + &output_data); EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -1070,15 +1342,14 @@ TEST_F(OpConverterTest, ConvertMatMul) { "Input expects tensor and weights, at my_matmul"); } - // Get the NodeDef for Reshape. + // Get the NodeDef for MatMul. 
auto get_matmul_nodedef = [](DataType dtype, bool transpose_a, bool transpose_b) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), dtype); auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); - ops::MatMul::Attrs matmul_attrs; - matmul_attrs.transpose_a_ = transpose_a; - matmul_attrs.transpose_b_ = transpose_b; + const auto matmul_attrs = + ops::MatMul::TransposeA(transpose_a).TransposeB(transpose_b); auto matmul = ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs); return matmul.operation.node()->def(); @@ -1094,45 +1365,990 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::UNIMPLEMENTED, "Data type is not supported, for node my_matmul got int32"); } - { - // transpose_a is set. - for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected"); + // transpose_a is set. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected"); + } + // OK. 
+ for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); + + std::vector output_data(2); + BuildAndRun({{"input", {0, 1}}}, "my_matmul", &output_data); + if (transpose_b) { + EXPECT_THAT(output_data, ElementsAre(1, 3)); + } else { + EXPECT_THAT(output_data, ElementsAre(2, 3)); } } - { - // OK. - for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunConversion(node_def); +} + +template +void TestConvertBiasAdd(OpConverterTest* test) { + // Get the NodeDef for BiasAdd. + auto get_biasadd_nodedef = [](const string& data_format) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format); + auto biasadd = + ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs); + return biasadd.operation.node()->def(); + }; + + typedef typename EnumToDataType::Type CType; + for (const string& data_format : {"NHWC", "NCHW"}) { + for (const int trt_input_rank : {1, 2, 3, 4}) { + test->Reset(); + NodeDef node_def = get_biasadd_nodedef(data_format); + + // Add input, dims_array will be like {2, 1, ..., 1, 3} + std::vector dims_array(trt_input_rank, 1); + if (trt_input_rank == 1) { + dims_array[0] = (data_format == "NHWC" ? 
3 : 2); + } else { + dims_array[0] = 2; + dims_array[trt_input_rank - 1] = 3; + } + test->AddTestTensor("input", dims_array, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + + // Add bias weights. + const int channel_size = (data_format == "NHWC" ? 3 : 2); + std::vector bias(channel_size); + for (int i = 0; i < channel_size; ++i) { + bias[i] = CType(i + 1); // bias will be {1, 2, 3, ...} + } + test->AddTestWeights("weights", {channel_size}, bias); + + // Run the conversion. + test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions())) - << output.DebugString(); - - std::vector output_data(2); - BuildAndRun("input", {0, 1}, "my_matmul", &output_data); - if (transpose_b) { - EXPECT_THAT(output_data, ElementsAre(1, 3)); + ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions()); + + // Build and run the engine. + const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); + ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2), + num_input); + std::vector output_data(num_input); + test->BuildAndRun( + {{"input", std::vector(num_input, CType(0))}}, "my_biasadd", + &output_data); + if (trt_input_rank == 1) { + if (data_format == "NHWC") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3))); + } else { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2))); + } } else { - EXPECT_THAT(output_data, ElementsAre(2, 3)); + if (data_format == "NHWC") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3), + CType(1), CType(2), CType(3))); + } else { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1), + CType(2), CType(2), CType(2))); + } } } } } +TEST_F(OpConverterTest, ConvertBiasAdd) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Input expects tensor and weights, at my_biasadd"); + } + + // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test + // DT_INT32 type here. + TestConvertBiasAdd(this); + TestConvertBiasAdd(this); +} + +template +NodeDef GetBinaryOpNodeDef(const string& input_name_l, + const string& input_name_r, DataType dtype) { + Scope s = Scope::NewRootScope(); + auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype); + auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype); + auto op = OpType(s.WithOpName("my_binary"), input_l, input_r); + return op.operation.node()->def(); +} + +void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { + bool element_wise_layer_found = false; + bool scale_layer_found = false; + for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) { + nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i); + if (dynamic_cast(layer)) { + scale_layer_found = true; + } else if (dynamic_cast(layer)) { + element_wise_layer_found = true; + } + } + EXPECT_EQ(expect_scale_layer, scale_layer_found); + EXPECT_NE(expect_scale_layer, element_wise_layer_found); +} + +template +void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + for (auto swap_inputs : {false, true}) { + test->Reset(); + NodeDef node_def; + if (swap_inputs) { + node_def = GetBinaryOpNodeDef("weights", "input", dtype); + } else { + node_def = GetBinaryOpNodeDef("input", "weights", dtype); + } + + const std::vector operand1{CType(3), CType(7.5)}; + const std::vector operand2{CType(2), CType(3)}; + + // It requires the dims to be at least of rank 3 to apply an IScaleLayer. 
+ test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", /*dims=*/{1, 1, 2}, + /*values=*/swap_inputs ? operand1 : operand2); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(2); + test->BuildAndRun( + {{"input", + /*input_data=*/swap_inputs ? operand2 : operand1}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5))); + } else if (node_def.op() == "Sub") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5))); + } else if (node_def.op() == "Mul") { + EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5))); + } else if (node_def.op() == "Div") { + EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + } else if (node_def.op() == "RealDiv") { + EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + } else { + ASSERT_TRUE(false); + } + } +} + +template +void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + const std::vector input{CType(1), CType(2), CType(3), CType(4)}; + const std::vector weights{CType(10), CType(20)}; + // There are two types of valid dim pairs which requires channel-wise + // broadcasting: + // - input dims (X Y Z) vs weights dims (X 1 1) + // - input dims (X Y Z) vs weights dims (Z) + // Here X=Z=2 and Y=1. 
+ for (auto weights_dims : std::vector>{{2, 1, 1}, {2}}) { + test->Reset(); + test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", weights_dims, weights); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + if (weights_dims.size() == 1) { + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(22), CType(13), CType(24))); + } else { + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(12), CType(23), CType(24))); + } + } +} + +template +void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + const std::vector input{CType(1), CType(2), CType(3), CType(4)}; + const std::vector weights{CType(10)}; + test->Reset(); + test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", {1, 1, 1, 1}, weights); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. 
+ TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(12), CType(13), CType(14))); +} + +template +void TestBinaryTensorOpWeightFallback(OpConverterTest* test, + const std::vector& input_dims, + const std::vector& weights_dims, + error::Code code = error::OK, + const char* error_msg_substr = nullptr, + const int input_batch_size = 1) { + const DataType dtype = DT_FLOAT; + typedef typename EnumToDataType::Type CType; + const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims)); + const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims)); + + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size, + TfDataTypeToTrt(dtype)); + test->AddTestWeights( + "weights", /*dims=*/weights_dims, + /*values=*/std::vector(num_weights, CType(1))); + test->RunValidationAndConversion(node_def, code, error_msg_substr); + if (code != error::OK) return; + + // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. + CheckAddedLayers(test, /*expect_scale_layer=*/false); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + + // Check the dims of the output ITensor. + std::vector expected_output_dims = input_dims; + for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1; + i >= 0 && j >= 0; --i, --j) { + if (expected_output_dims[i] == 1) { + expected_output_dims[i] = weights_dims[j]; + } + } + ExpectTrtDimsEqualsArray(expected_output_dims, + output.tensor()->getDimensions()); + + // Check the result of running the engine. 
+ const int expected_num_outputs = + TrtDimsNumElements(GetTestDims(expected_output_dims)); + std::vector output_data(expected_num_outputs); + test->BuildAndRun( + {{"input", + /*input_data=*/std::vector(num_inputs, CType(2))}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, ElementsAreArray(std::vector( + expected_num_outputs, CType(3)))); + } else if (node_def.op() == "Minimum") { + EXPECT_THAT(output_data, ElementsAreArray(std::vector( + expected_num_outputs, CType(1)))); + } else { + ASSERT_TRUE(false); + } +} + +template +void TestBinaryTensorOpTensor(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. + CheckAddedLayers(test, /*expect_scale_layer=*/false); + + // Check output dims. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + // After broadcasting first input becomes {3, 6, 3, 6} and second input + // becomes {2, 3, 2, 3}. 
+ test->BuildAndRun( + {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, + ElementsAre(CType(5), CType(8), CType(6), CType(9))); + } else if (node_def.op() == "Sub") { + EXPECT_THAT(output_data, + ElementsAre(CType(1), CType(4), CType(0), CType(3))); + } else if (node_def.op() == "Mul") { + EXPECT_THAT(output_data, + ElementsAre(CType(6), CType(12), CType(9), CType(18))); + } else if (node_def.op() == "Div") { + EXPECT_THAT(output_data, + ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + } else if (node_def.op() == "RealDiv") { + EXPECT_THAT(output_data, + ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + } else if (node_def.op() == "Minimum") { + EXPECT_THAT(output_data, + ElementsAre(CType(2), CType(2), CType(3), CType(3))); + } else if (node_def.op() == "Maximum") { + EXPECT_THAT(output_data, + ElementsAre(CType(3), CType(6), CType(3), CType(6))); + } else { + ASSERT_TRUE(false); + } +} + +TEST_F(OpConverterTest, ConvertBinary) { + // Input size doesn't match, should fail. + for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) { + Reset(); + NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); + AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Binary ops require two inputs, at my_add"); + } + { + // Both inputs are weights. + Reset(); + NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"}); + AddTestWeights("weights1", {1}, {1}); + AddTestWeights("weights2", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Constant folding is falled back to TensorFlow, binary op received " + "both input as constant at: my_add"); + } + + // Test BinaryTensorOpWeight() without broadcasting. 
+ TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); +#if 0 + // TODO(b/119560144): it doesn't support FP16 constants and the following test + // will fail. + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); +#endif + + // Test BinaryTensorOpWeight() with channel-wise broadcasting. + TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); + + // Test BinaryTensorOpWeight() with uniformly broadcasting. + TestBinaryTensorOpWeightWithUniformlyBroadcast(this); + + // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor(). + // Unsupported op. + TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1}); + // Rank of input tensor dimension <3. + TestBinaryTensorOpWeightFallback(this, {1, 1}, {1}); + // Broadcast on batch dimension, should fail. + TestBinaryTensorOpWeightFallback( + this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT, + "Unsupported binary op broadcast scheme for op my_binary", + /*input_batch_size=*/2); + // Incompatible dims with per-channel mode. + TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1, 2, 1}); + // Incompatible dims. + TestBinaryTensorOpWeightFallback(this, {1, 2, 1}, {2}); + + // Test BinaryTensorOpTensor() with broadcasting. 
+ TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); +} + +TEST_F(OpConverterTest, ConvertQuantize) { + for (const string& op : + {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars", + "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_quantize", op, {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + StrCat("Invalid number of inputs for ", op, ", at my_quantize") + .c_str()); + } + { + // FakeQuantWithMinMaxArgs attributes are empty, should fail. + NodeDef node_def = + MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Min or max attribute not found for FakeQuantWithMinMaxArgs " + "at my_quantize"); + } + { + // FakeQuantWithMinMaxArgs ranges set via attributes, ok. 
+ Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); + auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"), + input, quantize_attrs); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // FakeQuantWithMinMaxVars ranges set via inputs, ok. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::FakeQuantWithMinMaxVars( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // QuantizeAndDequantizeV2 ranges set via inputs, ok. 
+ Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights_min", {1}); + AddTestTensor("weights_max", {1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " + "tensors, at my_quantize"); + } + { + // QuantizeAndDequantizeV3 ranges set via inputs, ok. 
+ Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32); + auto quantize = ops::QuantizeAndDequantizeV3( + s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights_min", {1}, {-6.0f}); + AddTestWeights("weights_max", {1}, {6.0f}); + AddTestWeights("num_bits", {1}, {8}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } +} + +TEST_F(OpConverterTest, ConvertRelu6) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Invalid number of inputs for Relu6, at my_relu6"); + } + + // Get the NodeDef for Relu6. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input); + const NodeDef node_def = relu6.operation.node()->def(); + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1}, {1.0f}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Relu6 is only implemented for tensors, not weights, at my_relu6"); + } + { + // Clip tensor values and set quantization ranges, ok. 
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output)); + EXPECT_TRUE(output.is_tensor()); + auto ranges = quantization_ranges(); + EXPECT_EQ(ranges[output.tensor()], 6.0f); + + std::vector output_data(6); + BuildAndRun({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6", + &output_data); + EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); + } +} + +template +void TestConvertSquare(OpConverterTest* test) { + test->Reset(); + typedef typename EnumToDataType::Type CType; + + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto square = ops::Square(s.WithOpName("my_square"), input); + NodeDef node_def = square.operation.node()->def(); + + test->AddTestTensor("input", {1, 20}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions()); + + const int num_inputs = 20; + std::vector input_data(num_inputs); + std::vector expected_output_data(num_inputs); + for (int i = 0; i < 20; i++) { + const CType value = CType(i - 9); + input_data[i] = value; + expected_output_data[i] = value * value; + } + std::vector output_data(num_inputs); + test->BuildAndRun({{"input", input_data}}, "my_square", &output_data); + ExpectArrayNear(expected_output_data, output_data); +} + +TEST_F(OpConverterTest, ConvertSquare) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_square", "Square", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Square expects one input, at my_square"); + } + { + // Input is weights, should fail. 
+ Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto square = ops::Square(s.WithOpName("my_square"), input); + NodeDef node_def = square.operation.node()->def(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Square is only implemented for tensors, at my_square"); + } + + // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't + // test DT_INT32 type here. + TestConvertSquare(this); + // TODO(tmorris): Looks like there may be a bug with this layer for FP16 + // inputs. Disabling for now. + // TestConvertSquare(this); +} + +TEST_F(OpConverterTest, ConvertActivation) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_act", "Relu", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Relu expects one input, at my_act"); + } + { + // Input is weights, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto relu = ops::Relu(s.WithOpName("my_act"), input); + const NodeDef& node_def = relu.operation.node()->def(); + AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Relu is only implemented for tensors, at my_act"); + } + + // Get nodedef for activation layer. 
+ auto get_act_nodedef = [](string op_name) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + if (op_name == "Relu") { + auto act = ops::Relu(s.WithOpName("my_act"), input); + return act.operation.node()->def(); + } else if (op_name == "Sigmoid") { + auto act = ops::Sigmoid(s.WithOpName("my_act"), input); + return act.operation.node()->def(); + } else if (op_name == "Tanh") { + auto act = ops::Tanh(s.WithOpName("my_act"), input); + return act.operation.node()->def(); + } + EXPECT_TRUE(false); + return NodeDef(); + }; + // Get expected output for activation layer. + auto get_act_output = [](string op_name, float input) -> float { + if (op_name == "Relu") { + return (input > 0.0f) ? input : 0.0f; + } else if (op_name == "Sigmoid") { + return 1.0f / (1.0f + std::exp(-input)); + } else if (op_name == "Tanh") { + return std::tanh(input); + } + EXPECT_TRUE(false); + return 0; + }; + + // Ok. + for (string op_name : {"Relu", "Sigmoid", "Tanh"}) { + Reset(); + NodeDef node_def = get_act_nodedef(op_name); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + + const std::vector input_data = {-100, -2, -1, 0, 1, 100}; + std::vector output_data(6); + BuildAndRun({{"input", input_data}}, "my_act", &output_data); + for (int i = 0; i < input_data.size(); i++) { + const float expected_output = get_act_output(op_name, input_data[i]); + EXPECT_FLOAT_EQ(output_data[i], expected_output); + } + } +} + +TEST_F(OpConverterTest, ConvertExpandDims) { + { + // Input list is empty, should fail. 
+ NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Two inputs expected for ExpandDims, at my_expanddims"); + } + + // Get the NodeDef for ExpandDims. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); + auto expanddims = + ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); + const NodeDef& node_def = expanddims.operation.node()->def(); + + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ExpandDims expects tensor for input, at my_expanddims"); + } + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "ExpandDims expects weights for axis, at my_expanddims"); + } + { + // Add dim at batch dimension, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {1}, {0}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Add dim at batch dimension via negative axis, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-5}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Modifying batch dimension is not supported for ExpandDims, at " + "my_expanddims"); + } + { + // Axis > rank(input), should fail. 
+ Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {5}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + { + // Axis < -rank(input)-1, should fail. + Reset(); + AddTestTensor("input", {1, 2, 3}); + // Input is rank 4 (batch dim included) + AddTestWeights("weights", {1}, {-6}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for ExpandDims is invalid, must be in the range " + "[-rank(input) - 1, rank(input)], at my_expanddims"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, int axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + int axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kExpandDimsOKCases = 8; + TestParams ok_params[kExpandDimsOKCases] = { + TestParams{{2, 3}, 1, {1, 2, 3}}, TestParams{{2, 3}, -3, {1, 2, 3}}, + TestParams{{2, 3}, 3, {2, 3, 1}}, TestParams{{2, 3}, -1, {2, 3, 1}}, + TestParams{{2, 3}, 2, {2, 1, 3}}, TestParams{{2, 3}, -2, {2, 1, 3}}, + TestParams{{6}, 1, {1, 6}}, TestParams{{6}, -1, {6, 1}}, + }; + for (int i = 0; i < kExpandDimsOKCases; ++i) { + Reset(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", {1}, {ok_params[i].axis}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_expanddims", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); + } +} + +TEST_F(OpConverterTest, ConvertSqueeze) { + { + // 
Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "One input expected for Squeeze, at my_squeeze"); + } + { + // No attrs, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + const NodeDef& node_def = squeeze.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze is only implemented for explicit dims, at my_squeeze"); + } + + // Get the NodeDef for Squeeze. + auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axis); + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Squeeze expects tensor for input, at my_squeeze"); + } + { + // Squeeze batch dim, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({0}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze batch dim via negative axis, should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Cannot squeeze batch dimension, at my_squeeze"); + } + { + // Squeeze >= rank(input), should fail. 
+ Reset(); + NodeDef node_def = get_squeeze_nodedef({4}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + { + // Squeeze < -rank(input), should fail. + Reset(); + NodeDef node_def = get_squeeze_nodedef({-5}); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Axis for Squeeze is invalid, must be in the range " + "[-rank(input), rank(input)), at my_squeeze"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, const std::vector& axis, + const std::vector& expected_output_dims) + : input_dims(input_dims), + axis(axis), + expected_output_dims(expected_output_dims) {} + std::vector input_dims; + std::vector axis; + std::vector expected_output_dims; + }; + + // Ok. + const int kSqueezeOKCases = 10; + TestParams ok_params[kSqueezeOKCases] = { + TestParams{{1, 2, 3}, {1}, {2, 3}}, + TestParams{{1, 2, 3}, {-3}, {2, 3}}, + TestParams{{2, 3, 1}, {3}, {2, 3}}, + TestParams{{2, 3, 1}, {-1}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, + TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, + TestParams{{1, 6}, {1}, {6}}, + TestParams{{6, 1}, {2}, {6}}, + }; + for (int i = 0; i < kSqueezeOKCases; ++i) { + Reset(); + NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); + AddTestTensor("input", ok_params[i].input_dims); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + std::vector output_data(6); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze", + &output_data); + EXPECT_THAT(output_data, ElementsAre(1, 
2, 3, 4, 5, 6)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index b30d94b02824516906ea8880ac6de0bbee9e166c..c1688d4db88a270dcd202989f89a677ed10576d9 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -67,6 +67,9 @@ tensorflow::Status TRTOptimizationPass::Init( TF_RETURN_IF_ERROR(GetPrecisionMode( Uppercase(params.at("precision_mode").s()), &precision_mode_)); } + if (params.count("use_calibration")) { + use_calibration_ = params.at("use_calibration").b(); + } return tensorflow::Status::OK(); } @@ -187,8 +190,8 @@ tensorflow::Status TRTOptimizationPass::Optimize( *optimized_graph = item.graph; return tensorflow::Status::OK(); } - if (VLOG_IS_ON(2)) { - VLOG(2) << CurrentStackTrace(); + if (VLOG_IS_ON(3)) { + LOG(INFO) << CurrentStackTrace(); PrintDebugInfo(cluster, item); } int max_dim = -1; @@ -222,6 +225,12 @@ tensorflow::Status TRTOptimizationPass::Optimize( TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; + if (use_calibration_ && precision_mode_ != INT8MODE) { + LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. 
" + << "Falling back to use_calibration = False."; + use_calibration_ = false; + } + std::vector nodes_to_preserve; for (const auto& n : item.NodesToPreserve()) { auto tokens = str_util::Split(n, ":"); @@ -250,6 +259,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( cp.is_dyn_op = is_dynamic_op_; cp.cached_engine_batches = batches_; cp.max_cached_engines = max_cached_batches_; + cp.use_calibration = use_calibration_; auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; return status; diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index 71b51d13681cb3f75dad034f3fb0f73dea2bacc1..3e8dc0978e43e2e9ba07aaa09f74acfe8e59b9a7 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -38,7 +38,8 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { maximum_batch_size_(-1), is_dynamic_op_(false), max_cached_batches_(1), - max_workspace_size_bytes_(256LL << 20) { + max_workspace_size_bytes_(256LL << 20), + use_calibration_(true) { VLOG(1) << "Constructing " << name_; } @@ -67,6 +68,7 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { std::vector batches_; int max_cached_batches_; int64_t max_workspace_size_bytes_; + bool use_calibration_; }; } // namespace convert diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 019446813a56de6316a04c1738ae13d03e8f4713..bad568644bb1f8d01d4cb0a7c853ec47d6f19e45 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -124,8 +124,10 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); OP_REQUIRES_OK(context, 
GetPrecisionMode(precision_string, &precision_mode_)); - calibration_mode_ = - (precision_mode_ == INT8MODE && calibration_data.size() == 0); + OP_REQUIRES_OK(context, + context->GetAttr("use_calibration", &use_calibration_)); + calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE && + calibration_data.size() == 0); if (calibration_data.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); @@ -149,9 +151,6 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { - if (!calibration_mode_) { - VLOG(1) << "Executing native engine"; - } std::vector inputs; std::vector* outputs = new std::vector(); if (native_func_ == tensorflow::kInvalidHandle) { @@ -172,7 +171,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, inputs.push_back(ctx->input(i)); } helper->Ref(); // Increment count for calculating native graph - VLOG(1) << "Executing native segment " << name(); + VLOG(1) << "Executing native segment: " << name(); lib->Run(opts, native_func_, inputs, outputs, [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); @@ -192,6 +191,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper) { + VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); tensorflow::core::ScopedUnref sc(helper); // TODO(aaroey): remove the ResourceMgr singleton. 
@@ -303,12 +303,13 @@ bool TRTEngineOp::ExecuteTrtEngine( OpKernelContext* ctx, const int num_batch, nvinfer1::ICudaEngine* trt_engine_ptr, nvinfer1::IExecutionContext* trt_execution_context_ptr) { + VLOG(1) << "Executing TRT engine: " << name(); const bool kRetry = true; const int num_binding = ctx->num_inputs() + ctx->num_outputs(); std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { const string input_name = StrCat(kInputPHName, i); - const size_t binding_index = + const int binding_index = trt_engine_ptr->getBindingIndex(input_name.c_str()); if (binding_index == -1) { LOG(ERROR) << "Input node not found, at " << input_name; @@ -345,7 +346,7 @@ bool TRTEngineOp::ExecuteTrtEngine( for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const size_t binding_index = + const int binding_index = trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; @@ -491,13 +492,14 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, } TrtUniquePtrType engine; bool convert_successfully = false; - VLOG(0) << name() << " Constructing a new engine with batch size " - << batch_size; + LOG(INFO) << "Building a new TensorRT engine for " << name() + << " with batch size " << batch_size; // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. 
auto status = convert::ConvertGraphDefToEngine( segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, - &logger, allocator, calibrator_.get(), &engine, &convert_successfully); + &logger, allocator, calibrator_.get(), &engine, use_calibration_, + &convert_successfully); if (!status.ok()) { if (convert_successfully) { // This means it fail to build the engine even when the network is built @@ -567,8 +569,8 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( const int64 workspace_size_bytes = workspace_size_; cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes, platform_gpu_id, workspace_size_bytes]() { - VLOG(0) << "Starting calibration thread on device " << platform_gpu_id - << ", Calibration Resource @ " << cres; + LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id + << ", Calibration Resource @ " << cres; auto err = cudaSetDevice(platform_gpu_id); if (err != cudaSuccess) { // TODO(aaroey): should return error here. @@ -586,6 +588,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), &cres->engine_, + /*use_calibration=*/true, /*convert_successfully=*/nullptr); if (!s.ok()) { LOG(ERROR) << "Calibration failed: " << s; diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 8fe06758914261035c90a6fda3f114a63a8ac93a..b545f497f32d5a1a6960b748467ca189b7debf6c 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -130,6 +130,10 @@ class TRTEngineOp : public AsyncOpKernel { // The finalized calibrator for inference. std::unique_ptr calibrator_; + + // If true, create calibration graph for INT8 mode. Otherwise, we are using + // user-provided quantization ranges. 
+ bool use_calibration_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index e0c7b6272379a20e3dacb6cd7c3b39de735d844d..92405906eb76b043bc08b68e25e16ab40197dddf 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -16,6 +16,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" @@ -39,18 +40,19 @@ REGISTER_OP("TRTEngineOp") .Attr("cached_engine_batches: list(int) = []") .Attr("max_cached_engines_count: int = 1") .Attr("workspace_size_bytes: int") - .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") + .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}") .Attr("calibration_data: string = ''") + .Attr("use_calibration: bool = true") .Input("in_tensor: InT") - .Output("out_tensor: OutT"); -// TODO(jie): TF requires concrete output shape for concrete input shapes. -// This is tricky for batch dimension, since we cannot ensure which input -// would carry the correct batch dimension (for the current stage of the -// implementation, we do require all input tensor to carry the same batch -// size, but this could change in the future). Hence we disable shape -// inference function as a workaround. -// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); - + .Output("out_tensor: OutT") + // TODO(jie): TF requires concrete output shape for concrete input shapes. + // This is tricky for batch dimension, since we cannot ensure which input + // would carry the correct batch dimension (for the current stage of the + // implementation, we do require all input tensor to carry the same batch + // size, but this could change in the future). 
Hence we disable shape + // inference function as a workaround. + // .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow #endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index bb81fbf93f37b97d01bb1e10fefb8c7da64b329f..203b2697babe32b45523109708cbf062dceee33b 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -63,19 +63,20 @@ class TrtPrecisionMode(object): return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8] -def tensorrt_rewriter_config(rewriter_config=None, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=None): +def get_tensorrt_rewriter_config(rewriter_config=None, + max_batch_size=1, + max_workspace_size_bytes=2 << 20, + precision_mode=TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batch_sizes=None, + use_calibration=True): """Returns a RewriterConfig proto for TRT transformation. Args: - rewriter_config: a RewriterConfig proto to append the TensorRTOptimizer to. - If None, it will create one with default settings. + rewriter_config: a template RewriterConfig proto used to create a + TRT-enabled RewriterConfig. If None, it will use a default one. max_batch_size: max size for the input batch max_workspace_size_bytes: the maximum GPU temporary memory which the TRT engine can use at execution time. This corresponds to the 'workspaceSize' @@ -95,6 +96,15 @@ def tensorrt_rewriter_config(rewriter_config=None, use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. 
This is useful when we know the most common batch size(s) the application is going to generate. + use_calibration: this argument is ignored if precision_mode is not INT8. If + set to True, a calibration graph will be created to calibrate the missing + ranges. The calibration graph must be converted to an inference graph + using calib_graph_to_infer_graph() after running calibration. If set to + False, quantization nodes will be expected for every tensor in the graph + (excluding those which will be fused). If a range is missing, an error + will occur. Please note that accuracy may be negatively affected if there + is a mismatch between which tensors TRT quantizes and which tensors were + trained with fake quantization. Returns: A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. @@ -107,13 +117,16 @@ def tensorrt_rewriter_config(rewriter_config=None, rewriter_config, rewriter_config_pb2.RewriterConfig): raise TypeError("rewriter_config should be a RewriterConfig proto.") + rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() if rewriter_config is None: - rewriter_config = rewriter_config_pb2.RewriterConfig() # Layout optimizer may add Const nodes followed by Reshape nodes, thus we # need to run constant folding again. - rewriter_config.optimizers.extend(["constfold", "layout", "constfold"]) - rewriter_config.meta_optimizer_iterations = ( + rewriter_config_with_trt.optimizers.extend( + ["constfold", "layout", "constfold"]) + rewriter_config_with_trt.meta_optimizer_iterations = ( rewriter_config_pb2.RewriterConfig.ONE) + else: + rewriter_config_with_trt.CopyFrom(rewriter_config) if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): raise ValueError(("precision mode '{}' is not supported."
@@ -121,7 +134,7 @@ def tensorrt_rewriter_config(rewriter_config=None, precision_mode, TrtPrecisionMode.supported_precision_modes)) - optimizer = rewriter_config.custom_optimizers.add() + optimizer = rewriter_config_with_trt.custom_optimizers.add() optimizer.name = "TensorRTOptimizer" optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size optimizer.parameter_map["max_batch_size"].i = max_batch_size @@ -138,7 +151,8 @@ def tensorrt_rewriter_config(rewriter_config=None, "maximum_cached_engines items.") optimizer.parameter_map["cached_engine_batches"].list.i.extend( cached_engine_batch_sizes) - return rewriter_config + optimizer.parameter_map["use_calibration"].b = use_calibration + return rewriter_config_with_trt def create_inference_graph(input_graph_def, @@ -150,7 +164,7 @@ def create_inference_graph(input_graph_def, is_dynamic_op=False, maximum_cached_engines=1, cached_engine_batch_sizes=None, - rewriter_config=None, + use_calibration=True, input_saved_model_dir=None, input_saved_model_tags=None, output_saved_model_dir=None, @@ -182,8 +196,15 @@ def create_inference_graph(input_graph_def, use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. - rewriter_config: a RewriterConfig proto to append the TensorRTOptimizer to. - If None, it will create one with default settings. + use_calibration: this argument is ignored if precision_mode is not INT8. If + set to True, a calibration graph will be created to calibrate the missing + ranges. The calibration graph must be converted to an inference graph + using calib_graph_to_infer_graph() after running calibration. If set to + False, quantization nodes will be expected for every tensor in the graph + (excluding those which will be fused). If a range is missing, an error + will occur.
Please note that accuracy may be negatively affected if there + is a mismatch between which tensors TRT quantizes and which tensors were + trained with fake quantization. input_saved_model_dir: the directory to load the SavedModel which contains the input graph to transforms. Used only when input_graph_def is None. input_saved_model_tags: list of tags to load the SavedModel. @@ -191,8 +212,9 @@ def create_inference_graph(input_graph_def, returned GraphDef and save it to the specified directory. This option only works when the input graph is loaded from a SavedModel, i.e. when input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. If not specified, - a default ConfigProto will be used. + session_config: the ConfigProto used to create a Session. It's also used as + a template to create a TRT-enabled ConfigProto for conversion. If not + specified, a default ConfigProto will be used. Returns: A GraphDef transformed from input_graph_def (or the SavedModel graph def @@ -322,21 +344,30 @@ def create_inference_graph(input_graph_def, grappler_meta_graph_def.collection_def["train_op"].CopyFrom( output_collection) - # Create RewriterConfig. - rewriter_config = tensorrt_rewriter_config( + # Create TRT-enabled ConfigProto. 
+ session_config_with_trt = config_pb2.ConfigProto() + session_config_with_trt.CopyFrom(session_config) + rewriter_config = None + if (session_config_with_trt.HasField("graph_options") and + session_config_with_trt.graph_options.HasField("rewrite_options")): + rewriter_config = session_config_with_trt.graph_options.rewrite_options + rewriter_config_with_trt = get_tensorrt_rewriter_config( rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, minimum_segment_size, is_dynamic_op, maximum_cached_engines, - cached_engine_batch_sizes) + cached_engine_batch_sizes, use_calibration) + session_config_with_trt.graph_options.rewrite_options.CopyFrom( + rewriter_config_with_trt) # Run Grappler. transformed_graph_def = tf_optimizer.OptimizeGraph( - rewriter_config, grappler_meta_graph_def, graph_id=b"tf_graph") + session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph") # Optionally write the transformed graphdef as SavedModel. if output_saved_model_dir is not None: saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) with ops.Graph().as_default(): importer.import_graph_def(transformed_graph_def, name="") + # We don't use TRT here. 
with session.Session(config=session_config) as sess: saved_model_builder.add_meta_graph_and_variables( sess, diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py index 9f2eeac990dcacb547d336b68bc042016c3e6171..a7b2d2ea50543ba85c5a13dd6ca320e794ca47f1 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py @@ -47,9 +47,9 @@ from tensorflow.python.tools import saved_model_utils class TrtConvertTest(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration python API.""" - def testTensorrtRewriterConfig(self): - """Test case for trt_convert.tensorrt_rewriter_config().""" - rewriter_cfg = trt_convert.tensorrt_rewriter_config( + def testGetTensorrtRewriterConfig(self): + """Test case for trt_convert.get_tensorrt_rewriter_config().""" + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( rewriter_config=None, max_batch_size=128, max_workspace_size_bytes=1234, @@ -162,7 +162,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase): node_name_to_op = {node.name: node.op for node in graph_def.node} self.assertEqual({ "input": "Placeholder", - "my_trt_op_0": "TRTEngineOp", + "TRTEngineOp_0": "TRTEngineOp", "output": "Identity" }, node_name_to_op) @@ -188,11 +188,12 @@ class TrtConvertTest(test_util.TensorFlowTestCase): self.assertAllEqual([[[4.0]]] * batch_size, result) execute_engine_test_value = ("done" if expect_engine_is_run else "") execute_native_segment_test_value = ("" if expect_engine_is_run else "done") - self.assertEqual(execute_engine_test_value, - trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine")) + self.assertEqual( + execute_engine_test_value, + trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine")) self.assertEqual( execute_native_segment_test_value, - trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment")) + 
trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment")) def testCreateInferenceGraph_MinimumSegmentSize(self): if not trt_convert.is_tensorrt_enabled(): diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 840da6e78d88392b3c1ef5c9f6e31a2f355d09f1..aac9e5c7bd725fc10bcaa04536ebc7be071b4d4c 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -39,7 +39,8 @@ namespace tensorrt { class TRTCalibrationResource : public tensorflow::ResourceBase { public: ~TRTCalibrationResource() { - VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + LOG(INFO) << "Destroying Calibration Resource " << std::endl + << DebugString(); builder_.reset(); engine_.reset(); // We need to manually destroy the builder and engine before the allocator diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 4f64b7a9522a177624baeb425ed643c5bff7e65f..6abc5226ccf96e472df77269bee6186726e5768d 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -33,6 +33,7 @@ namespace tensorflow { namespace tensorrt { namespace segment { using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing @@ -406,22 +407,42 @@ tensorflow::Status SegmentGraph( // Use a union-find to collect the nodes that belong to the same // segment. A node value of nullptr indicates that the node is not a candidate // for TRT. 
+ std::unordered_set unsupported_ops; + int num_unsupported_ops = 0; std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); if (options.exclude_node_list.count(node->name()) != 0) { - VLOG(1) << "Not a TF-TRT candidate: " << node->name() - << " (excluded by segmenter option)."; + VLOG(1) << "Not a TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: excluded by segmenter option)"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; node = nullptr; } else { const Status status = candidate_fn(node->tf_node()); if (!status.ok()) { - VLOG(1) << "Not a TF-TRT candidate: " << node->name() << ": " << status; + VLOG(1) << "Not a TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << "), " + << "(Reason: " << status << ")"; + unsupported_ops.emplace(node->tf_node()->type_string()); + num_unsupported_ops++; node = nullptr; } } node_segments.emplace_back(node); } + string msg = StrCat( + "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(), + " different types in the graph that", " are not converted to TensorRT: "); + for (const auto& elem : unsupported_ops) { + StrAppend(&msg, elem, ", "); + } + LOG(INFO) << msg << "(For more information see " + << "https://docs.nvidia.com/deeplearning" + << "/dgx/integrate-tf-trt/index.html#support-ops)."; // The segmentation algorithm below visits nodes in reverse topological order // and attempts to merge nodes along output edges. 
That means that subgraphs diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index 18096e0ff1ec6b9872346d8a84ac93c542cfb643..ff317e43e1e6ff1c0b869ae8dc6d1fda8f0ce126 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -56,8 +56,9 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): strides=[1, 2, 2, 1], padding="SAME", name="conv") - bias = constant_op.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + bias = constant_op.constant([4., 1.5, 2., 3., 5., 7.], + name="bias", + dtype=dtype) added = nn.bias_add(conv, bias, name="bias_add") relu = nn.relu(added, "relu") identity = array_ops.identity(relu, "identity") @@ -73,11 +74,12 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. 
- # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", - # "relu", "identity", "max_pool"] - return ["my_trt_op_0"] + return { + "TRTEngineOp_0": [ + "weights", "conv", "bias", "bias_add", "relu", "identity", + "max_pool" + ] + } class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -92,7 +94,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): g = ops.Graph() with g.as_default(): inp = array_ops.placeholder( - dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + dtype=dtype, shape=input_dims, name=input_name) with g.device("/GPU:0"): conv_filter = constant_op.constant( [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], @@ -105,10 +107,10 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): padding="SAME", name="conv") c1 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + np.random.randn(12, 12, 6), dtype=dtype, name="c1") p = math_ops.mul(conv, c1, name="mul") c2 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + np.random.randn(12, 12, 6), dtype=dtype, name="c2") q = math_ops.div(conv, c2, name="div") edge = self.trt_incompatible_op(q, name="incompatible") @@ -129,22 +131,21 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. 
- # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", - # "add", "sub1"]; - # - my_trt_op_1 should have ["weights","conv", "div"] - return ["my_trt_op_0", "my_trt_op_1"] + return { + "TRTEngineOp_0": [ + "add", "add1", "c1", "div1", "mul", "mul1", "sub", "sub1" + ], + "TRTEngineOp_1": ["c2", "conv", "div", "weights"] + } - def ShouldRunTest(self, run_params): - # TODO(aaroey): LayoutOptimizer adds Transpose(Const, Const) to the graph - # which breaks the conversion. We should fix it as: - # - Detect the invalid NodeDef earlier before adding them to segment - # - Let it able to change the RewriterConfig when calling - # create_inference_graph(). - # It will be good to add debugging feature for Grappler to print the graph - # after running each optimizer. - return False + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + return super( + SimpleMultiEnginesTest, self + ).GetConversionParams(run_params)._replace( + # Disable layout optimizer, since it'll add Transpose(Const, Const) to + # the graph and breaks the conversion check. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): @@ -153,7 +154,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): """Setup method.""" super(PartiallyConvertedTestA, self).setUp() # Let it fail to build the second engine. - trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + trt_convert.add_test_value("TRTEngineOp_1:CreateTRTNode", "fail") def GetParams(self): """Create a graph containing two segment.""" @@ -190,14 +191,16 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" return { # Only the first engine is built. 
- "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + "TRTEngineOp_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] } def ShouldRunTest(self, run_params): """Whether to run the test.""" # Disable the test in fp16 mode since multiple matmul and add ops together # can cause overflow. - return run_params.precision_mode != "FP16" + return ((run_params.precision_mode != "FP16") and + not (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_calibration)) class PartiallyConvertedTestB(PartiallyConvertedTestA): @@ -207,13 +210,13 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA): super(PartiallyConvertedTestB, self).setUp() # Let it fail to build the first engine. trt_convert.clear_test_values("") - trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + trt_convert.add_test_value("TRTEngineOp_0:CreateTRTNode", "fail") def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { # Only the second engine is built. 
- "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + "TRTEngineOp_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] } @@ -257,8 +260,8 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["add", "add1", "mul"], - "my_trt_op_1": ["add2", "add3", "mul1"] + "TRTEngineOp_0": ["add", "add1", "mul"], + "TRTEngineOp_1": ["add2", "add3", "mul1"] } @@ -289,7 +292,7 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return {"my_trt_op_0": ["c", "add", "add1", "mul"]} + return {"TRTEngineOp_0": ["c", "add", "add1", "mul"]} class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -324,12 +327,12 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["add2", "add3", "mul1"], + "TRTEngineOp_0": ["add2", "add3", "mul1"], # Why segment ["add", "add1", "mul"] was assigned segment id 1 # instead of 0: the parent node of this segment is actually const # node 'c', but it's removed later since it's const output of the # segment which is not allowed. 
- "my_trt_op_1": ["add", "add1", "mul"] + "TRTEngineOp_1": ["add", "add1", "mul"] } @@ -373,8 +376,8 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["c1", "add", "add1", "mul"], - "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + "TRTEngineOp_0": ["c1", "add", "add1", "mul"], + "TRTEngineOp_1": ["c2", "add2", "add3", "mul1"] } diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index 4b8880817876143dc753cfacdb79d4ad50347fe0..f42308ecb7c8f8a107e78008abd3f470ddc85975 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -79,12 +79,12 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): """Return the expected engines to build.""" if (run_params.dynamic_engine and not trt_test.IsQuantizationMode(run_params.precision_mode)): - return ["my_trt_op_0", "my_trt_op_1"] - return ["my_trt_op_1"] + return ["TRTEngineOp_0", "TRTEngineOp_1"] + return ["TRTEngineOp_1"] def ExpectedEnginesToRun(self, run_params): """Return the expected engines to run.""" - return ["my_trt_op_1"] + return ["TRTEngineOp_1"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 7545bb9df20f295a8fdbc82b573cdb3407f8c5e4..053b38ff1c0578c58f39dd6dc0630d1401a105af 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -41,6 +41,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): input_name = "input" input_matrix_rows = 4 input_matrix_columns = 144 + # Note that tf.nn.bias_add supports up to 5 dimensions. 
input_dims = [input_matrix_rows, input_matrix_columns] output_name = "output" g = ops.Graph() @@ -74,18 +75,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): x5 = nn.bias_add(x5, b) x5 = gen_array_ops.reshape(x5, [4, -1]) - x6 = gen_array_ops.reshape(x, [4, 12, 12]) - b = self._ConstOp((12,)) + x6 = gen_array_ops.reshape(x, [4, 24, 6]) + b = self._ConstOp((6,)) x6 = nn.bias_add(x6, b, data_format="NHWC") x6 = gen_array_ops.reshape(x6, [4, -1]) - x7 = gen_array_ops.reshape(x, [4, 12, 3, 4]) - b = self._ConstOp((4,)) + x7 = gen_array_ops.reshape(x, [4, 12, 4, 3]) + b = self._ConstOp((3,)) x7 = nn.bias_add(x7, b, data_format="NHWC") x7 = gen_array_ops.reshape(x7, [4, -1]) - x8 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) - b = self._ConstOp((2,)) + x8 = gen_array_ops.reshape(x, [4, 4, 3, 2, 6]) + b = self._ConstOp((6,)) x8 = nn.bias_add(x8, b, data_format="NHWC") x8 = gen_array_ops.reshape(x8, [4, -1]) @@ -94,13 +95,13 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): x9 = nn.bias_add(x9, b, data_format="NCHW") x9 = gen_array_ops.reshape(x9, [4, -1]) - x10 = gen_array_ops.reshape(x, [4, 12, 3, 4]) - b = self._ConstOp((12,)) + x10 = gen_array_ops.reshape(x, [4, 3, 4, 12]) + b = self._ConstOp((3,)) x10 = nn.bias_add(x10, b, data_format="NCHW") x10 = gen_array_ops.reshape(x10, [4, -1]) - x11 = gen_array_ops.reshape(x, [4, 12, 12]) - b = self._ConstOp((12,)) + x11 = gen_array_ops.reshape(x, [4, 6, 24]) + b = self._ConstOp((6,)) x11 = nn.bias_add(x11, b, data_format="NCHW") x11 = gen_array_ops.reshape(x11, [4, -1]) @@ -116,13 +117,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" - return super(BiasaddMatMulTest, - self).GetConversionParams(run_params)._replace( - max_batch_size=4, maximum_cached_engines=1) + conversion_params = super(BiasaddMatMulTest, + self).GetConversionParams(run_params) + return conversion_params._replace( + 
max_batch_size=4, + maximum_cached_engines=1, + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four-dimensional input. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] def ShouldRunTest(self, run_params): """Whether to run the test.""" diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index b53cb3c091ea477ef0974d9d14d82c587a431152..169835956c046dd675e967daa05fd81405662e38 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -56,10 +55,10 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): ]: a = self._ConstOp(weights_shape) f = x + a - x = math_ops.sigmoid(f) + x = self.trt_incompatible_op(f) a = self._ConstOp(weights_shape) f = a + x - x = math_ops.sigmoid(f) + x = self.trt_incompatible_op(f) gen_array_ops.reshape(x, [5, -1], name=output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), @@ -70,7 +69,7 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_%d" % i for i in range(16)] + return ["TRTEngineOp_%d" % i for i in range(16)] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py
index 465cb022964df046bf03a481bb1c6b65750aa883..c3576f81d97afe7e0e42cd10413971911e97774c 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -79,7 +79,7 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index e32f0478661caaab5386339c819b524656baf066..c1c883312d867b60b88ac14318041f9750ca41e6 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -64,7 +64,7 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ['my_trt_op_0'] + return ['TRTEngineOp_0'] def ExpectedAbsoluteTolerance(self, run_params): """The absolute tolerance to compare floating point results.""" diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py index bc7c90081ff38a832b523948db10c02de7acefc2..104bac43a0b1166dcddee9920991582f33e93316 100644 --- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -68,7 +68,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] def ExpectedAbsoluteTolerance(self, run_params): """The absolute tolerance to compare floating point results.""" diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py 
b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index 11be4feaf7bf8ce6c8bd16f1546dc17450c342f1..293f93d8a78bc8ab06002d6fc01cb8d6a0738698 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -25,8 +25,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -60,14 +58,14 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): b = constant_op.constant( np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) q = conv - b - edge = math_ops.sigmoid(q) + edge = self.trt_incompatible_op(q) b = constant_op.constant( np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) d = b + conv - edge3 = math_ops.sigmoid(d) + edge3 = self.trt_incompatible_op(d) - edge1 = gen_math_ops.tan(conv) + edge1 = self.trt_incompatible_op(conv) t = t - edge1 q = q + edge t = t + q @@ -83,7 +81,7 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0", "my_trt_op_1"] + return ["TRTEngineOp_0", "TRTEngineOp_1"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index eddeafa38bc71743ac6c9d8e5e8db76f28ca7bf4..3e1e4b088ba200db2184dd64092cbc642a17cb3a 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -66,8 +66,8 @@ class 
NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["bias", "mul", "sub"], - "my_trt_op_1": ["weights", "conv"] + "TRTEngineOp_0": ["bias", "mul", "sub"], + "TRTEngineOp_1": ["weights", "conv"] } diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py new file mode 100644 index 0000000000000000000000000000000000000000..31cbef89e23949ba5ceaab34e0f683fd906bf0ce --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -0,0 +1,290 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Script to test TF-TRT INT8 conversion without calibration on Mnist model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.python import trt_convert +# pylint: disable=unused-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +# pylint: enable=unused-import +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import data +from tensorflow.python import keras +from tensorflow.python.estimator.estimator import Estimator +from tensorflow.python.estimator.model_fn import EstimatorSpec +from tensorflow.python.estimator.model_fn import ModeKeys +from tensorflow.python.estimator.run_config import RunConfig +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras.datasets import mnist +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import saver +from tensorflow.python.training.adam import AdamOptimizer +from tensorflow.python.training.checkpoint_management import latest_checkpoint +from tensorflow.python.training.training_util import get_global_step + +INPUT_NODE_NAME = 'input' +OUTPUT_NODE_NAME = 'output' + + +class 
QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): + + def _BuildGraph(self, x): + + def _Quantize(x, r): + x = gen_array_ops.quantize_and_dequantize_v2(x, -r, r) + return x + + def _DenseLayer(x, num_inputs, num_outputs, quantization_range, name): + """Dense layer with quantized outputs. + + Args: + x: input to the dense layer + num_inputs: number of input columns of x + num_outputs: number of output columns + quantization_range: the min/max range for quantization + name: name of the variable scope + + Returns: + The output of the layer. + """ + with variable_scope.variable_scope(name): + kernel = variable_scope.get_variable( + 'kernel', + shape=[num_inputs, num_outputs], + dtype=dtypes.float32, + initializer=keras.initializers.glorot_uniform()) + bias = variable_scope.get_variable( + 'bias', + shape=[num_outputs], + dtype=dtypes.float32, + initializer=keras.initializers.zeros()) + x = math_ops.matmul(x, kernel) + x = _Quantize(x, quantization_range) + x = nn.bias_add(x, bias) + x = _Quantize(x, quantization_range) + return x + + x = _Quantize(x, 1) + # Conv + Bias + Relu6 + x = layers.conv2d(x, filters=32, kernel_size=3, use_bias=True) + x = nn.relu6(x) + # Conv + Bias + Relu6 + x = layers.conv2d(x, filters=64, kernel_size=3, use_bias=True) + x = nn.relu6(x) + # Reduce + x = math_ops.reduce_mean(x, [1, 2]) + x = _Quantize(x, 6) + # FC1 + x = _DenseLayer(x, 64, 512, 6, name='dense') + x = nn.relu6(x) + # FC2 + x = _DenseLayer(x, 512, 10, 25, name='dense_1') + x = array_ops.identity(x, name=OUTPUT_NODE_NAME) + return x + + def _GetGraphDef(self, use_trt, max_batch_size, model_dir): + """Get the frozen mnist GraphDef. + + Args: + use_trt: whether use TF-TRT to convert the graph. + max_batch_size: the max batch size to apply during TF-TRT conversion. + model_dir: the model directory to load the checkpoints. + + Returns: + The frozen mnist GraphDef. 
+ """ + graph = ops.Graph() + with self.session(graph=graph) as sess: + with graph.device('/GPU:0'): + x = array_ops.placeholder( + shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME) + self._BuildGraph(x) + # Load weights + mnist_saver = saver.Saver() + checkpoint_file = latest_checkpoint(model_dir) + mnist_saver.restore(sess, checkpoint_file) + # Freeze + graph_def = graph_util.convert_variables_to_constants( + sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME]) + # Convert with TF-TRT + if use_trt: + logging.info('Number of nodes before TF-TRT conversion: %d', + len(graph_def.node)) + graph_def = trt_convert.create_inference_graph( + graph_def, + outputs=[OUTPUT_NODE_NAME], + max_batch_size=max_batch_size, + precision_mode='INT8', + max_workspace_size_bytes=4096 << 19, + minimum_segment_size=2, + use_calibration=False, + ) + logging.info('Number of nodes after TF-TRT conversion: %d', + len(graph_def.node)) + num_engines = len( + [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp']) + self.assertEqual(1, num_engines) + return graph_def + + def _Run(self, is_training, use_trt, batch_size, num_epochs, model_dir): + """Train or evaluate the model. + + Args: + is_training: whether to train or evaluate the model. In training mode, + quantization will be simulated where the quantize_and_dequantize_v2 are + placed. + use_trt: if true, use TRT INT8 mode for evaluation, which will perform + real quantization. Otherwise use native TensorFlow which will perform + simulated quantization. Ignored if is_training is True. + batch_size: batch size. + num_epochs: how many epochs to train. Ignored if is_training is False. + model_dir: where to save or load checkpoint. + + Returns: + The Estimator evaluation result. 
+ """ + # Get dataset + train_data, test_data = mnist.load_data() + + def _PreprocessFn(x, y): + x = math_ops.cast(x, dtypes.float32) + x = array_ops.expand_dims(x, axis=2) + x = 2.0 * (x / 255.0) - 1.0 + y = math_ops.cast(y, dtypes.int32) + return x, y + + def _EvalInputFn(): + mnist_x, mnist_y = test_data + dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = dataset.apply( + data.experimental.map_and_batch( + map_func=_PreprocessFn, + batch_size=batch_size, + num_parallel_calls=8)) + dataset = dataset.repeat(count=1) + iterator = data.make_one_shot_iterator(dataset) + features, labels = iterator.get_next() + return features, labels + + def _TrainInputFn(): + mnist_x, mnist_y = train_data + dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = dataset.shuffle(2 * len(mnist_x)) + dataset = dataset.apply( + data.experimental.map_and_batch( + map_func=_PreprocessFn, + batch_size=batch_size, + num_parallel_calls=8)) + dataset = dataset.repeat(count=num_epochs) + iterator = data.make_one_shot_iterator(dataset) + features, labels = iterator.get_next() + return features, labels + + def _ModelFn(features, labels, mode): + if is_training: + logits_out = self._BuildGraph(features) + else: + graph_def = self._GetGraphDef(use_trt, batch_size, model_dir) + logits_out = importer.import_graph_def( + graph_def, + input_map={INPUT_NODE_NAME: features}, + return_elements=[OUTPUT_NODE_NAME + ':0'], + name='')[0] + + loss = losses.sparse_softmax_cross_entropy( + labels=labels, logits=logits_out) + summary.scalar('loss', loss) + + classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out') + accuracy = metrics.accuracy( + labels=labels, predictions=classes_out, name='acc_op') + summary.scalar('accuracy', accuracy[1]) + + if mode == ModeKeys.EVAL: + return EstimatorSpec( + mode, loss=loss, eval_metric_ops={'accuracy': accuracy}) + elif mode == ModeKeys.TRAIN: + optimizer = AdamOptimizer(learning_rate=1e-2) + train_op = 
optimizer.minimize(loss, global_step=get_global_step()) + return EstimatorSpec(mode, loss=loss, train_op=train_op) + + config_proto = config_pb2.ConfigProto() + config_proto.gpu_options.allow_growth = True + estimator = Estimator( + model_fn=_ModelFn, + model_dir=model_dir if is_training else None, + config=RunConfig(session_config=config_proto)) + + if is_training: + estimator.train(_TrainInputFn) + results = estimator.evaluate(_EvalInputFn) + logging.info('accuracy: %s', str(results['accuracy'])) + return results + + # To generate the checkpoint, set a different model_dir and call self._Run() + # by setting is_training=True and num_epochs=100, e.g.: + # model_dir = '/tmp/quantization_mnist' + # self._Run( + # is_training=True, + # use_trt=False, + # batch_size=128, + # num_epochs=100, + # model_dir=model_dir) + def testEval(self): + if not trt_convert.is_tensorrt_enabled(): + return + model_dir = test.test_src_dir_path('contrib/tensorrt/test/testdata') + + accuracy_tf_native = self._Run( + is_training=False, + use_trt=False, + batch_size=128, + num_epochs=None, + model_dir=model_dir)['accuracy'] + logging.info('accuracy_tf_native: %f', accuracy_tf_native) + self.assertAllClose(accuracy_tf_native, 0.9662) + + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return + + accuracy_tf_trt = self._Run( + is_training=False, + use_trt=True, + batch_size=128, + num_epochs=None, + model_dir=model_dir)['accuracy'] + logging.info('accuracy_tf_trt: %f', accuracy_tf_trt) + self.assertAllClose(accuracy_tf_trt, 0.9677) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/quantization_test.py b/tensorflow/contrib/tensorrt/test/quantization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e425a3674635650d7292ab072178e98932e6b824 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/quantization_test.py @@ -0,0 +1,144 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.python import trt_convert +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def _GetParams(add_quantization_nodes, dtype=dtypes.float32): + input_name = "input" + input_dims = [8, 8] + output_name = "output" + + def _Quantize(x, r): + if add_quantization_nodes: + x = gen_array_ops.fake_quant_with_min_max_vars(x, -r, r) + return x + + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + x = _Quantize(x, 10.0) + x = x + 5 + x = _Quantize(x, 15.0) + x = x - 5 + x = _Quantize(x, 10.0) + x = x * 0.1 + x = _Quantize(x, 1.0) + w = constant_op.constant(np.ones((8, 1)), dtype=dtypes.float32) + x = math_ops.matmul(x, w) + x = _Quantize(x, 10.0) + x = array_ops.identity(x, 
name=output_name) + + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + output_names=[output_name], + expected_output_dims=[(8, 1)]) + + +class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment with no quantization ranges.""" + return _GetParams(add_quantization_nodes=False) + + def ShouldRunTest(self, run_params): + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return False + # Only test static engine mode, with or without calibration. + return (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_optimizer and not run_params.dynamic_engine) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + if run_params.use_calibration: + # In static engine mode with calibration, it should build a calibration + # engine. + return ["TRTEngineOp_0"] + # In static engine mode without calibration, the engine building will fail + # since no quantization ranges are set, which results in no TRT nodes. + return [] + + +class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment with quantization ranges.""" + return _GetParams(add_quantization_nodes=True) + + def ShouldRunTest(self, run_params): + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return False + # Test static/dynamic engine with/without calibration. 
+ return (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_optimizer) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["TRTEngineOp_0"] + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + +class NonQuantizedPrecisionsWithRangesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph with quantization ranges provided by fake quant nodes.""" + return _GetParams(add_quantization_nodes=True) + + def ShouldRunTest(self, run_params): + # Only test FP32/FP16 mode. + return not trt_test.IsQuantizationMode(run_params.precision_mode) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + # The fake quant ops are not supported in FP32/FP16 mode, and will split the + # graph into four TRT segments. 
+ return ["TRTEngineOp_0", "TRTEngineOp_1", "TRTEngineOp_2", "TRTEngineOp_3"] + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index 74a4a059257ffde4c86df1f18b3ce35c3790ec7a..563232fc12675d9e1b32b7ab461591af57beadb9 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -51,8 +51,10 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): c = constant_op.constant(3.0, name="c%d_3" % i) q = math_ops.add(q, c, name="add%d_3" % i) if i == 0: + axis = constant_op.constant(-1, dtype=dtypes.int32, name="axis") for j in range(2): - q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j)) + q = array_ops.expand_dims(q, axis, name="expand%d_%d" % (i, j)) + q = self.trt_incompatible_op(q) q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i) outputs.append(q) # Combine both paths @@ -68,11 +70,11 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": [ + "TRTEngineOp_0": [ "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", - "abs0_2" + "abs0_2", "expand0_0", "expand0_1", "axis" ], - "my_trt_op_1": [ + "TRTEngineOp_1": [ "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", "abs1_1", "abs1_2", "reciprocal0", "reciprocal1" ], diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py index 
bbc724ab18e18be3e831732071a31f0a541a4059..207944468ab0b038abfe01f0096d7dc220d064ed 100644 --- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py +++ b/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py @@ -79,8 +79,8 @@ class ReshapeTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": ["reshape-%d" % i for i in range(7)] + - ["reshape-%d/shape" % i for i in range(7)] + "TRTEngineOp_0": ["reshape-%d" % i for i in range(7)] + + ["reshape-%d/shape" % i for i in range(7)] } def ShouldRunTest(self, run_params): @@ -117,7 +117,7 @@ class TransposeTest(trt_test.TfTrtIntegrationTestBase): # Note: by default Grappler will run the TRT optimizer twice. At the # first time it will group the two transpose ops below to same segment # then fail the conversion due to the expected batch dimension problem. - # At the second time, since the input of bridge op is my_trt_op_0, it + # At the second time, since the input of bridge op is TRTEngineOp_0, it # will fail to do shape inference which then cause conversion to fail. # TODO(laigd): support shape inference, make TRT optimizer run only # once, and fix this. 
@@ -136,7 +136,7 @@ class TransposeTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" return { - "my_trt_op_0": [ + "TRTEngineOp_0": [ "transpose-1", "transpose-1/perm", "transposeback", "transposeback/perm" ] diff --git a/tensorflow/contrib/tensorrt/test/testdata/checkpoint b/tensorflow/contrib/tensorrt/test/testdata/checkpoint new file mode 100644 index 0000000000000000000000000000000000000000..a603e1aec91adab04fd9801ba05a2ee9adfbb6e8 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/testdata/checkpoint @@ -0,0 +1,3 @@ +model_checkpoint_path: "model.ckpt-46900" +all_model_checkpoint_paths: "model.ckpt-0" +all_model_checkpoint_paths: "model.ckpt-46900" diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..88a998f184b275121e1e76eb51d2310da149f10a Binary files /dev/null and b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 differ diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index new file mode 100644 index 0000000000000000000000000000000000000000..537976571337508ab1798d33646c51d62a146ecc Binary files /dev/null and b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index differ diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index a725d0651c92fe18bcfd284cffd40cdfec2e6c69..495a9391a1e818a6078988161c9bf72f6143737f 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tensorrt.python import trt_convert from 
tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer @@ -42,14 +43,15 @@ TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ "gdef", "input_names", "input_dims", "output_names", "expected_output_dims" ]) -RunParams = namedtuple( - "RunParams", - ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) +RunParams = namedtuple("RunParams", [ + "use_optimizer", "precision_mode", "dynamic_engine", "test_name", + "use_calibration" +]) ConversionParams = namedtuple("ConversionParams", [ "max_batch_size", "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", - "cached_engine_batch_sizes", "rewriter_config" + "cached_engine_batch_sizes", "rewriter_config", "use_calibration" ]) PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -65,6 +67,34 @@ class GraphState(object): INFERENCE = 2 +def OptimizerDisabledRewriterConfig(): + """Returns a RewriterConfig with all default Grappler optimizers disabled.""" + rewriter_config = rewriter_config_pb2.RewriterConfig() + + # Turn off all default Grappler optimizers. 
+ off = rewriter_config_pb2.RewriterConfig.OFF + rewriter_config.layout_optimizer = off + rewriter_config.constant_folding = off + rewriter_config.shape_optimization = off + rewriter_config.remapping = off + rewriter_config.arithmetic_optimization = off + rewriter_config.dependency_optimization = off + rewriter_config.loop_optimization = off + rewriter_config.function_optimization = off + rewriter_config.debug_stripper = off + rewriter_config.disable_model_pruning = True + rewriter_config.scoped_allocator_optimization = off + rewriter_config.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.NO_MEM_OPT) + rewriter_config.pin_to_host_optimization = off + rewriter_config.auto_parallel.enable = False + + # Run only once for each enabled optimizer. + rewriter_config.meta_optimizer_iterations = ( + rewriter_config_pb2.RewriterConfig.ONE) + return rewriter_config + + class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" @@ -139,11 +169,15 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=1, cached_engine_batch_sizes=None, - rewriter_config=None) + rewriter_config=None, + use_calibration=run_params.use_calibration) def ShouldRunTest(self, run_params): """Whether to run the test.""" - return True + # This setting combination requires quantization nodes to be present in + # order to build the engine. + return not (IsQuantizationMode(run_params.precision_mode) and + not run_params.use_calibration) def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True): """Verify the state of a particular engine after sess.run().""" @@ -194,34 +228,35 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _PrepareRun(self, graph_state): """Set up necessary testing environment before calling sess.run().""" # Clear test values added by TRTEngineOp. 
- trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") - trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") - trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteCalibration") + trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment") + + def _GetGPUOptions(self): + gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True + return gpu_options def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: conversion_params = self.GetConversionParams(run_params) - rewriter_cfg = trt_convert.tensorrt_rewriter_config( + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( conversion_params.rewriter_config, conversion_params.max_batch_size, conversion_params.max_workspace_size_bytes, conversion_params.precision_mode, conversion_params.minimum_segment_size, conversion_params.is_dynamic_op, conversion_params.maximum_cached_engines, - conversion_params.cached_engine_batch_sizes) + conversion_params.cached_engine_batch_sizes, + conversion_params.use_calibration) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: graph_options = config_pb2.GraphOptions() - gpu_options = config_pb2.GPUOptions() - gpu_options.allow_growth = True - if trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options.per_process_gpu_memory_fraction = 0.50 - config = config_pb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) + gpu_options=self._GetGPUOptions(), graph_options=graph_options) return config def _ExpectTestValue(self, engine_name, method, expected_value): @@ -291,6 +326,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): params = self._GetParamsCached() conversion_params = self.GetConversionParams(run_params) 
logging.info(conversion_params) + + config_for_trt = config_pb2.ConfigProto(gpu_options=self._GetGPUOptions()) + if conversion_params.rewriter_config is not None: + config_for_trt.graph_options.rewrite_options.CopyFrom( + conversion_params.rewriter_config) return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=params.input_names + params.output_names, @@ -301,7 +341,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_op=conversion_params.is_dynamic_op, maximum_cached_engines=conversion_params.maximum_cached_engines, cached_engine_batch_sizes=conversion_params.cached_engine_batch_sizes, - rewriter_config=conversion_params.rewriter_config) + use_calibration=conversion_params.use_calibration, + session_config=config_for_trt) def _WriteGraph(self, run_params, gdef, graph_state): if graph_state == GraphState.ORIGINAL: @@ -400,10 +441,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_engine = not node.attr["static_engine"].b self.assertEqual(run_params.dynamic_engine, is_dynamic_engine, node.name) + self.assertEqual(node.attr["use_calibration"].b, + run_params.use_calibration, node.name) has_calibration_data = len(node.attr["calibration_data"].s) if (IsQuantizationMode(run_params.precision_mode) and - graph_state == GraphState.INFERENCE): + run_params.use_calibration and graph_state == GraphState.INFERENCE): self.assertTrue(has_calibration_data, node.name) else: self.assertFalse(has_calibration_data, node.name) @@ -438,6 +481,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # types. scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0 dims = params.input_dims[i] + # TODO(laigd): add debug options. E.g. 
we can set the input data to be + # continuous natural numbers: + # seq = np.arange(np.prod(dims)) + # seq.resize(dims) + # input_data.append(scale * seq.astype(dtype)) input_data.append((scale * np.random.random_sample(dims)).astype(dtype)) self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL) @@ -449,7 +497,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): config_no_trt, GraphState.ORIGINAL) # Run calibration if necessary. - if IsQuantizationMode(run_params.precision_mode): + if (IsQuantizationMode(run_params.precision_mode) and + run_params.use_calibration): calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) @@ -519,27 +568,38 @@ def _AddTests(test_class): use_optimizer_options = [False, True] dynamic_engine_options = [False, True] - for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( - use_optimizer_options, PRECISION_MODES, dynamic_engine_options): + use_calibration_options = [False, True] + opts = itertools.product(use_optimizer_options, PRECISION_MODES, + dynamic_engine_options, use_calibration_options) + for (use_optimizer, precision_mode, dynamic_engine, use_calibration) in opts: if IsQuantizationMode(precision_mode): if use_optimizer: # TODO(aaroey): if use_optimizer is True we need to get the inference # graphdef using custom python wrapper class, which is not currently # supported yet. continue - if not dynamic_engine: + if use_calibration and not dynamic_engine: + # Static engine with use_calibration=False will be static, so we want to + # test that. If use_calibration=True, only dynamic op is supported. # TODO(aaroey): construction of static calibration engine is not # supported yet. 
continue + else: + if use_calibration: + # Don't calibrate in FP32 or FP16 mode + continue conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" - engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") - test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine" + calibration_type = "UseCalibration" if use_calibration else "NoCalibration" + test_name = "%s_%s_%s_%s" % (conversion, engine_type, precision_mode, + calibration_type) run_params = RunParams( use_optimizer=use_optimizer, precision_mode=precision_mode, dynamic_engine=dynamic_engine, - test_name=test_name) + test_name=test_name, + use_calibration=use_calibration) setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index 8736bfb6449b3c25a411ec081ad58b1f8be84617..b6e5e32db1236684a06c2d44298b9a3d39667152 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -106,10 +106,7 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return [ - "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", - "my_trt_op_4" - ] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index b0271a04b364864b841c2ec9fe53aac74611b2c3..b29626d2c28b4def716aef9e2703b669b5e46374 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -76,7 +76,7 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] if 
__name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index d7c165784bfe14bb5faffd266770328237a3eb80..9b0b189626050f678c71e9abbf7eb5296440d879 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -67,7 +67,7 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - return ["my_trt_op_0"] + return ["TRTEngineOp_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tfprof/README.md b/tensorflow/contrib/tfprof/README.md index b29d1acacf17b57549558be45c853566817c1729..f40e76f554e8815aac96344d8cb0b911bafdd712 100644 --- a/tensorflow/contrib/tfprof/README.md +++ b/tensorflow/contrib/tfprof/README.md @@ -1,7 +1,5 @@ # tfprof: TensorFlow Profiler and Beyond -

Please use `tf.profiler.xxx` instead of `tf.contrib.tfprof.xxx`

-

Full Document in tensorflow/core/profiler/README.md

diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index c230919168b937b26c68e141e15f0762ad70f3e6..4b90b596b28efec83aa349782c4874d79b6817c7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -104,8 +104,10 @@ py_test( srcs = [ "estimators_test.py", ], + shard_count = 3, srcs_version = "PY2AND3", tags = [ + "no_mac", "no_pip_gpu", # b/63391119 "nomsan", # Takes too long to run. "notsan", # b/67865658 diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index af68aa03cf6583dc474eda6cda2e648fa1c3d08d..146ed9f27134e3e2a6c74627b6b78e53d65155f0 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -32,7 +32,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.estimator.export import export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index ffd838be40ed6267109fe36d95a681496fb2f964..7d780559f976516823611f3fe0ded056e4be088c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -30,7 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils from tensorflow.python.client 
import session from tensorflow.python.estimator import estimator_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 90c7d8ac1a9c69216ece74af458cd750667f51ee..8f692d94da45bfaed6c72cf75d525346865aea34 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -38,7 +38,7 @@ from tensorflow.core.example import example_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 43c5267e632e464d43ffcbcf6c551ff83d3c5767..aab330643862c1ccf073d2a0e34e1c475b1ec15f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -802,7 +802,7 @@ class InputStatisticsFromMiniBatch(object): array_ops.shape(times)[1] - 1, self._dtype)) # Co-locate updates with their variables to minimize race conditions when # updating statistics. - with ops.colocate_with(auxiliary_variables.max_time_seen): + with ops.device(auxiliary_variables.max_time_seen.device): # There is a race condition if this value is being updated from multiple # workers. 
However, it should eventually reach the correct value if the # last chunk is presented enough times. @@ -810,16 +810,16 @@ class InputStatisticsFromMiniBatch(object): auxiliary_variables.max_time_seen, gen_math_ops.maximum(auxiliary_variables.max_time_seen, math_ops.reduce_max(times))) - with ops.colocate_with(auxiliary_variables.chunk_count): + with ops.device(auxiliary_variables.chunk_count.device): chunk_count_assign = state_ops.assign_add(auxiliary_variables.chunk_count, array_ops.shape( times, out_type=dtypes.int64)[0]) - with ops.colocate_with(auxiliary_variables.inter_observation_duration_sum): + with ops.device(auxiliary_variables.inter_observation_duration_sum.device): inter_observation_duration_assign = state_ops.assign_add( auxiliary_variables.inter_observation_duration_sum, math_ops.reduce_sum(batch_inter_observation_duration)) - with ops.colocate_with(auxiliary_variables.example_count): + with ops.device(auxiliary_variables.example_count.device): example_count_assign = state_ops.assign_add( auxiliary_variables.example_count, array_ops.size(times, out_type=dtypes.int64)) @@ -829,11 +829,11 @@ class InputStatisticsFromMiniBatch(object): # the series are then members of fewer chunks. For series which are much # longer than the chunk size (the usual/expected case), this effect becomes # irrelevant. 
- with ops.colocate_with(auxiliary_variables.overall_feature_sum): + with ops.device(auxiliary_variables.overall_feature_sum.device): overall_feature_sum_assign = state_ops.assign_add( auxiliary_variables.overall_feature_sum, math_ops.reduce_sum(values, axis=[0, 1])) - with ops.colocate_with(auxiliary_variables.overall_feature_sum_of_squares): + with ops.device(auxiliary_variables.overall_feature_sum_of_squares.device): overall_feature_sum_of_squares_assign = state_ops.assign_add( auxiliary_variables.overall_feature_sum_of_squares, math_ops.reduce_sum(values**2, axis=[0, 1])) @@ -869,7 +869,7 @@ class InputStatisticsFromMiniBatch(object): state_ops.assign(statistics.series_start_moments.mean, mean), state_ops.assign(statistics.series_start_moments.variance, variance)) - with ops.colocate_with(statistics.start_time): + with ops.device(statistics.start_time.device): series_start_update = control_flow_ops.cond( # Update moments whenever we even match the lowest time seen so far, # to ensure that series start statistics are eventually updated to diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index edd97b2a4c131dbce0a5111dbac7d40eddea2bae..a8cd4287e0003de300b7114cf3f88d21d3239e6e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -27,7 +27,7 @@ from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape diff --git 
a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 3c07a74ed8af9e3ab70408f9b43cb62b6bd4c7f2..125750e7639ad40c481472a93353e6fb7055be96 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -40,7 +40,10 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["state_space_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_mac", + "no_windows", # TODO: needs investigation on Windows + ], deps = [ ":state_space_model", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index a0a9cb3f31a945a00eb3f6a5fd1402aab9a2df5f..4bf3a0463d9046eea2f60e9154fca1357e728215 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -14,6 +14,7 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") package( default_visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", + "//knowledge/cerebra/sense/im2query:__subpackages__", "//learning/brain:__subpackages__", "//learning/deepmind:__subpackages__", "//medical/pathology:__subpackages__", @@ -78,6 +79,7 @@ py_library( "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:summary_ops_v2", @@ -215,7 +217,7 @@ py_library( ], deps = [ ":tpu_lib", - "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/distribute", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", @@ -263,7 +265,7 @@ py_library( ":tpu_py", 
"//tensorflow/compiler/xla/experimental/xla_sharding", "//tensorflow/compiler/xla/python_api:xla_shape", - "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index b4b06a40a2c8aaa97ff82baf93c8f2d55a587e37..ef35e84ba5205fb76e5afe77e670d87197ca8405 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -98,7 +98,7 @@ Status DumpOpProfileToLogDirectory(StringPiece run_dir, if (!status.ok()) { return errors::Internal( "Failed to convert op profile to json. Skipping... ", - string(status.error_message())); + string(status.message())); } TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json)); if (os) { diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 63641e00c5dbf4b4e635ecfea8bef98c7d0b7075..a081c4354a779d37140338793e66844c3fcf7a12 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -90,12 +90,12 @@ def main(unused_argv=None): tf_version = tf.__version__ print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu is None: + if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None - if FLAGS.service_addr is not None: - if FLAGS.tpu is not None: + if FLAGS.service_addr: + if FLAGS.tpu: tf.logging.warn('Both --service_addr and --tpu are set. 
Ignoring ' '--tpu and using --service_addr.') service_addr = FLAGS.service_addr diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index 1cf7f9fcf67ec98feb02dd4298a36153e689f2e5..1b09ce173a64ba3f93ec019c8fd65dc4710f0fcf 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -80,6 +80,8 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): self._summary_writer = None self._global_step_tensor = None + self._last_checkpoint_step = None + def _set_steps_per_run(self, steps_per_run): self._steps_per_run = steps_per_run @@ -137,8 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): last_step = session.run(self._global_step_tensor) - # Save the last checkpoint synchronously if needed. - if last_step != self._timer.last_triggered_step(): + if self._last_checkpoint_step != last_step: self._save(session, last_step, asynchronous=False) for l in self._listeners: @@ -174,6 +175,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): logging.info("Checkpoint finished for %d into %s.", step, self._save_path) if not asynchronous: + self._last_checkpoint_step = step _save_fn() return @@ -183,6 +185,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): logging.info("Saver thread still in progress, skipping checkpoint.") return + self._last_checkpoint_step = step self._save_thread = threading.Thread(target=_save_fn) self._save_thread.start() diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index c694e9c1bca10d9930492c29dd1c3cbc7f7f5d04..8d6245390fc3fa005c92d01bc9b64ddb47583582 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -133,7 +133,7 @@ def StreamingFilesDataset(files, with ops.device('/job:%s' % 
file_reader_job): if isinstance(files, str): source_dataset = dataset_ops.Dataset.list_files(files) - elif isinstance(files, dataset_ops.Dataset): + elif isinstance(files, dataset_ops.DatasetV2): source_dataset = files else: raise ValueError('files was not a string or a dataset: %s' % files) @@ -156,7 +156,7 @@ def StreamingFilesDataset(files, source_dataset = source_dataset.prefetch(1) - source_iterator = source_dataset.make_one_shot_iterator() + source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) source_handle = source_iterator.string_handle() @function.Defun(dtypes.string) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index b58d05eac56f3586e183333f7c1a3867ee57456c..52d87b800401c3e584da9843916cfc7a767c082a 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -70,7 +70,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -94,7 +94,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -121,7 +121,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -154,7 +154,7 @@ class DatasetsTest(test.TestCase): 
os.path.join(self.get_temp_dir(), 'fixed_length*'), filetype=FixedLengthFile) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() @@ -177,7 +177,7 @@ class DatasetsTest(test.TestCase): dataset = datasets.StreamingFilesDataset( dataset_ops.Dataset.range(10), filetype=gen_dataset) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) self._sess.run(iterator.initializer) get_next = iterator.get_next() diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 08f58a5f5b89f92502893e222cbca3bd07b2432b..4ce194590342555a7c4e9e119bf51e516a37a715 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -81,6 +81,7 @@ from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import models from tensorflow.python.keras import optimizers as keras_optimizers from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import training_arrays from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.layers import embeddings @@ -132,7 +133,7 @@ def _tpu_session_context(): An error occurred connecting or initializing your TPU. The session has been reset. re-run keras_to_tpu_model to create a new session. 
-""" + e) +""" + str(e)) def setup_tpu_session(cluster_resolver): @@ -438,7 +439,7 @@ class TPURewriteContext(object): self._default_placeholder = array_ops.placeholder self._default_name_scope = ops.name_scope - self._default_make_variable = base_layer.make_variable + self._default_make_variable = base_layer_utils.make_variable self._default_random_normal = random_ops.random_normal self._default_qr = gen_linalg_ops.qr @@ -486,14 +487,14 @@ class TPURewriteContext(object): gen_linalg_ops.qr = qr ops.name_scope = _name_scope - base_layer.make_variable = variable_scope.get_variable + base_layer_utils.make_variable = variable_scope.get_variable logging.info('Overriding default placeholder.') return def __exit__(self, exc_type, exc_val, exc_tb): array_ops.placeholder = self._default_placeholder ops.name_scope = self._default_name_scope - base_layer.make_variable = self._default_make_variable + base_layer_utils.make_variable = self._default_make_variable random_ops.random_normal = self._default_random_normal gen_linalg_ops.qr = self._default_qr @@ -728,7 +729,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() dummy_y_shape[0] *= tpu_assignment.num_towers - self._iterator = dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(dataset) K.get_session().run(self._iterator.initializer) self._get_next_ops = [] @@ -769,7 +770,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): def _verify_dataset_shape(self, dataset): """Verifies a dataset is of an appropriate shape for TPUs.""" - if not isinstance(dataset, dataset_ops.Dataset): + if not isinstance(dataset, dataset_ops.DatasetV2): raise ValueError('The function passed as the `x` parameter did not ' 'return a `tf.data.Dataset`.') if not isinstance(dataset.output_classes, tuple): @@ -1012,9 +1013,10 @@ class TPUFunction(object): optimizer=_replicated_optimizer(self._cloned_optimizer), 
loss=self.model.loss, loss_weights=self.model.loss_weights, - metrics=metrics_module.clone_metrics(self.model.metrics), + metrics=metrics_module.clone_metrics( + self.model._compile_metrics), weighted_metrics=metrics_module.clone_metrics( - self.model.weighted_metrics), + self.model._compile_weighted_metrics), target_tensors=tpu_targets, ) @@ -1184,12 +1186,9 @@ class TPUFunction(object): # pipelined loop. return None, None - if not isinstance(K.learning_phase(), int): + if isinstance(inputs[-1], int): # Remove the learning_phase flag at the end. We currently hard code the # learning_phase in TPUFunction. - assert isinstance(inputs[-1], int), ( - 'Expect the final element be learning_phase flag. Got {}'.format( - inputs[-1])) inputs = inputs[:-1] if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or @@ -1379,6 +1378,7 @@ class KerasTPUModel(models.Model): self.train_function = None self._fit_function = None self._eval_function = None + self._stateful_metric_functions = [] cluster_resolver = strategy._tpu_cluster_resolver self._tpu_name_or_address = cluster_resolver.get_master() @@ -1393,10 +1393,10 @@ class KerasTPUModel(models.Model): self.compile( self._cpu_model.optimizer, self._cpu_model.loss, - self._cpu_model.metrics, + self._cpu_model._compile_metrics, self._cpu_model.loss_weights, self._cpu_model.sample_weight_mode, - self._cpu_model.weighted_metrics, + self._cpu_model._compile_weighted_metrics, self._cpu_model.target_tensors, ) @@ -1466,7 +1466,7 @@ class KerasTPUModel(models.Model): assert not self._numpy_to_infeed_manager_list # Ensure empty. infeed_managers = [] # Managers to clean up at the end of the fit call. - if isinstance(x, dataset_ops.Dataset): + if isinstance(x, dataset_ops.DatasetV2): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( 'Taking a Dataset directly is not yet supported. 
Please ' @@ -1492,7 +1492,7 @@ class KerasTPUModel(models.Model): y = infeed_manager.dummy_y infeed_managers.append((x, infeed_manager)) - if isinstance(validation_data, dataset_ops.Dataset): + if isinstance(validation_data, dataset_ops.DatasetV2): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( 'Taking a Dataset directly is not yet supported. Please ' @@ -1551,7 +1551,7 @@ class KerasTPUModel(models.Model): with _tpu_session_context(): # Managers to clean up at the end of the evaluate call. infeed_managers = [] - if isinstance(x, dataset_ops.Dataset): + if isinstance(x, dataset_ops.DatasetV2): # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( 'Taking a Dataset directly is not yet supported. Please ' @@ -1676,14 +1676,10 @@ class KerasTPUModel(models.Model): callbacks, self, do_validation=do_validation, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, samples=num_training_samples, - validation_steps=validation_steps, verbose=verbose, count_mode=count_mode) @@ -1700,7 +1696,7 @@ class KerasTPUModel(models.Model): callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): # Reset stateful metrics - for m in self.stateful_metric_functions: + for m in self.metrics: m.reset_states() # Update callbacks callbacks.on_epoch_begin(epoch) @@ -1923,7 +1919,7 @@ class KerasTPUModel(models.Model): if validation_data: if (isinstance(validation_data, iterator_ops.Iterator) or isinstance(validation_data, iterator_ops.EagerIterator) or - isinstance(validation_data, dataset_ops.Dataset)): + isinstance(validation_data, dataset_ops.DatasetV2)): raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator ' 'for validation_data. 
Please instead pass a function ' 'that returns a `tf.data.Dataset`.') @@ -1998,14 +1994,14 @@ class KerasTPUModel(models.Model): self._optimizer = optimizer @property - def stateful_metric_functions(self): + def metrics(self): if self._tpu_model: - return self._tpu_model.stateful_metric_functions + return self._tpu_model.metrics return self._stateful_metric_functions - @stateful_metric_functions.setter - def stateful_metric_functions(self, stateful_metric_functions): - self._stateful_metric_functions = stateful_metric_functions + @metrics.setter + def metrics(self, metrics): + self._stateful_metric_functions = metrics def _make_train_function(self): if not self.train_function: @@ -2230,10 +2226,10 @@ def tpu_model(model, strategy=None): cpu_model.compile( _clone_optimizer(model.optimizer, optimizer_config), model.loss, - metrics_module.clone_metrics(model.metrics), + metrics_module.clone_metrics(model._compile_metrics), model.loss_weights, model.sample_weight_mode, - metrics_module.clone_metrics(model.weighted_metrics), + metrics_module.clone_metrics(model._compile_weighted_metrics), ) if model_weights: diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 28d3a938510a450ccba0d921663d848e2adec72f..8b0b240dc7302c203a22349d583323327fc4480b 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -217,6 +217,10 @@ class ReplicatedVariable(object): def get(self): return self._primary_var + @property + def _in_graph_mode(self): + return self._primary_var._in_graph_mode # pylint: disable=protected-access + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 
a95275487899c4770ef99b620a7671eec2bb81eb..3e463823c820a3ef8628324f77e1a9caf8d385d5 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -43,12 +43,19 @@ class CoordinatorShutdownException(Exception): pass +def _clone_session(session, graph=None): + return session_lib.Session( + target=session.sess_str, + config=session._config, # pylint: disable=protected-access + graph=graph if graph else session.graph) + + def _make_heartbeat_op(session, device, request_ph): """Return a heartbeat op or None if heartbeats are not supported by device.""" try: # Test if we can connect in a isolated graph + session with ops.Graph().as_default(): - with session_lib.Session(target=session.sess_str) as temp_session: + with _clone_session(session) as temp_session: with ops.device(device): heartbeat_op = tpu_ops.worker_heartbeat('') options = config_pb2.RunOptions(timeout_in_ms=5000) @@ -220,6 +227,7 @@ class WatchdogManager(threading.Thread): self.ping_interval = ping_interval self.shutdown_timeout = shutdown_timeout self.daemon = True + self._config = session._config # pylint: disable=protected-access self._target = session.sess_str self._running = False self._devices = devices @@ -234,6 +242,7 @@ class WatchdogManager(threading.Thread): self._session = session_lib.Session( target=self._target, graph=self._graph, + config=self._config, ) if self._devices is None: @@ -334,8 +343,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): with self._graph.as_default(): logging.info('Installing graceful shutdown hook.') - self._session = session_lib.Session( - target=training_session.sess_str, graph=self._graph) + self._session = _clone_session(training_session, self._graph) self._workers = WorkerHeartbeatManager.from_devices( self._session, all_worker_devices(self._session)) self._heartbeat_supported = self._workers.num_workers() > 0 diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py 
b/tensorflow/contrib/tpu/python/tpu/tpu.py index e3e791faacb9b3c1fedbd83d3740e35351e38abb..def57da20d6018dcf27ccb7a9d04592f38ce2f7c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -1001,8 +1001,8 @@ def rewrite(computation, `rewrite` is a list of tensors corresponding to the tensors from the output of `computation`. - All `Operation`s returned from `computation` will be executed when - evaluating any of the returned output tensors. + All `Operation`s constructed during `computation` will be executed when + evaluating any of the returned output tensors, not just the ones returned. inputs: A list of input tensors or `None` (equivalent to an empty list). infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. @@ -1111,7 +1111,7 @@ def validate_inference_rewrite_for_variables(graph): Raises: RuntimeError: if validation failed. """ - if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]): + if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): raise RuntimeError( "No GuaranteeConst ops found in the graph after running " "tpu.rewrite_for_inference(...). 
Please check that you are using " diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index da6bdf67d686fba09d66386de982b57aa28d4dd4..672462447944b777375331d49727c4d5366cf295 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -41,7 +41,7 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = { class TPUContext(object): - """The context of current input_fn invocation.""" + """A context that holds the current configuration of the TPU computation.""" def __init__(self, internal_ctx, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index 3fe896426a7ae5b4b15b0520522002e6fb0dc1b0..ccba8a46c7cad0337119672e02314684f4451479 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -1069,17 +1069,14 @@ def _create_partitioned_variables(name, 'As TPU embedding is not optimized for small tables, ' 'please consider other ways for this embedding lookup.') - slicing = [num_hosts, 1] - - # TODO(shizhiw): deprecated, use tf.get_variable()? 
- return partitioned_variables.create_partitioned_variables( - name=name, - slicing=slicing, + return list(variable_scope.get_variable( + name, shape=(vocabulary_size, embedding_dimension), + partitioner=partitioned_variables.fixed_size_partitioner(num_hosts), dtype=dtypes.float32, initializer=initializer, collections=collections, - trainable=False) + trainable=False)) @ops.RegisterGradient('TPUEmbeddingActivations') diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 7cb8c4aa7f14636a9597ec45974ec013ef367414..96b9556e137effcaaa5916b9723142f737a6dc33 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -45,6 +45,7 @@ from tensorflow.contrib.training.python.training import hparam from tensorflow.core.framework import variable_pb2 from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest as data_nest from tensorflow.python.estimator import estimator as estimator_lib @@ -298,9 +299,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote host_calls['host_call'] = host_call _OutfeedHostCall.validate(host_calls) - training_hooks = list(training_hooks or []) - evaluation_hooks = list(evaluation_hooks or []) - prediction_hooks = list(prediction_hooks or []) + training_hooks = tuple(training_hooks or []) + evaluation_hooks = tuple(evaluation_hooks or []) + prediction_hooks = tuple(prediction_hooks or []) for hook in training_hooks + evaluation_hooks + prediction_hooks: if not isinstance(hook, session_run_hook.SessionRunHook): @@ -335,7 +336,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = 
[_OutfeedHostCallHook(host_call_ret['host_call'])] - hooks = list(hooks or []) + hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -412,12 +413,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=True, - rendezvous=None): + rendezvous=None, + master=None, + session_config=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops self._rendezvous = rendezvous - + self._master = master + self._session_config = session_config self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) @@ -429,11 +433,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_ops = [] if self._should_initialize_tpu: - self._init_ops = [tpu.initialize_system(job=self._master_job)] self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] else: - self._init_ops = [] self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() @@ -475,11 +478,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return _OpQueueContext(name=name, target=target, args=args) def after_create_session(self, session, coord): - logging.info('Init TPU system') - start = time.time() + if self._should_initialize_tpu: + logging.info('Init TPU system') + start = time.time() + with ops.Graph().as_default(): + with tf_session.Session( + self._master, config=self._session_config) as sess: + sess.run(tpu.initialize_system(job=self._master_job)) + logging.info('Initialized TPU in %d seconds', time.time() - start) + session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 
* 1000)) - logging.info('Initialized TPU in %d seconds', time.time() - start) self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) @@ -521,13 +530,16 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): + def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None, + master=None, session_config=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous) + rendezvous=rendezvous, + master=master, + session_config=session_config) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -2169,7 +2181,6 @@ class TPUEstimator(estimator_lib.Estimator): builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, export_tags=None, @@ -2184,7 +2195,6 @@ class TPUEstimator(estimator_lib.Estimator): builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables, mode=mode, export_tags=export_tags, @@ -2201,7 +2211,6 @@ class TPUEstimator(estimator_lib.Estimator): builder, input_receiver_fn_map, checkpoint_path, - strip_default_attrs, save_variables=False, mode=mode, export_tags=export_tags, @@ -2225,7 +2234,7 @@ class TPUEstimator(estimator_lib.Estimator): def computation(): """Compute tpu tensors used in export_outputs. - Passed to rewrite_for_inference so that model_fn will be called under + Passed to rewrite so that model_fn will be called under the rewriting contexts. Only tpu tensors are returned, but export_outputs and scaffold are captured. @@ -2234,7 +2243,7 @@ class TPUEstimator(estimator_lib.Estimator): outside_compilation. 
""" # We should only call model fn once and it should be inside `computation` - # so that building the graph will happen under `rewrite_for_inference`. + # so that building the graph will happen under `rewrite`. mode = model_fn_lib.ModeKeys.PREDICT estimator_spec = self._call_model_fn(features, labels, mode, config) @@ -2251,7 +2260,7 @@ class TPUEstimator(estimator_lib.Estimator): capture.capture((estimator_spec, tensors_dict, tensors)) return tpu_tensors - tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) + tpu_tensors_on_cpu = tpu.rewrite(computation) estimator_spec, tensors_dict, tensors = capture.get() # Reconstruct `tensors`, but with `tpu_tensors` replaced with @@ -2564,6 +2573,8 @@ class TPUEstimator(estimator_lib.Estimator): run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), rendezvous=self._rendezvous[mode], + master=self._config.master, + session_config=self._session_config, ), InstallSignalHandlerHook() ]) @@ -2666,8 +2677,10 @@ class TPUEstimator(estimator_lib.Estimator): eval_update_ops + host_ops, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode]), - ] + input_hooks + rendezvous=self._rendezvous[mode], + master=self._config.evaluation_master, + session_config=self._session_config, + )] + input_hooks if eval_hooks: hooks.extend(eval_hooks) @@ -2738,7 +2751,9 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], + master=self._config.master, + session_config=self._session_config), ] + input_hooks if prediction_hooks: @@ -2783,7 +2798,7 @@ def _export_output_to_tensors(export_output): elif isinstance(export_output, export_output_lib.RegressionOutput): return [export_output.value] elif isinstance(export_output, export_output_lib.PredictOutput): - return 
export_output.outputs.values() + return list(export_output.outputs.values()) else: raise ValueError( '`export_output` must be have type `ClassificationOutput`, ' @@ -3059,7 +3074,7 @@ class _Inputs(object): @staticmethod def from_input_fn(return_values): """Returns an `_Inputs` instance according to `input_fn` return value.""" - if isinstance(return_values, dataset_ops.Dataset): + if isinstance(return_values, dataset_ops.DatasetV2): dataset = return_values return _Inputs(dataset=dataset) @@ -3084,7 +3099,7 @@ class _Inputs(object): The initializer must be run before calling `features_and_labels`. """ - self._iterator = self._dataset.make_initializable_iterator() + self._iterator = dataset_ops.make_initializable_iterator(self._dataset) return self._iterator.initializer def features_and_labels(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py index 3786e52b949dfac8c1587d1ea3041b625f00183f..e3ea983abfd24d03c964fbc647b56262e15e0a96 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.python import data as dataset_lib from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -34,10 +34,10 @@ def make_input_fn(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = 
dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) dataset = dataset.batch(batch_size) return dataset @@ -50,10 +50,10 @@ def make_input_fn_with_labels(num_samples): def input_fn(params): batch_size = params['batch_size'] - da1 = dataset_lib.Dataset.from_tensor_slices(a) - da2 = dataset_lib.Dataset.from_tensor_slices(b) + da1 = dataset_ops.Dataset.from_tensor_slices(a) + da2 = dataset_ops.Dataset.from_tensor_slices(b) - dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset_ops.Dataset.zip((da1, da2)) dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) dataset = dataset.batch(batch_size) return dataset @@ -71,7 +71,7 @@ class TPUEstimatorStoppingSignalsTest(test.TestCase): with ops.Graph().as_default(): dataset = input_fn(params) - features = dataset.make_one_shot_iterator().get_next() + features = dataset_ops.make_one_shot_iterator(dataset).get_next() # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. self.assertIsNone(features['a'].shape.as_list()[0]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index e75a09492ec12b95bad32b221a8e78a1b79f3a6b..d5957b7e8ec40b40c7af8822378cee6134ef0d0f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -26,7 +26,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding @@ -92,8 +91,7 @@ class InfeedQueue(object): else: raise ValueError( "number of tuple elements cannot be inferred from InfeedQueue " - "constructor" - ) + "constructor") if number_of_tuple_elements <= 0: raise ValueError("number_of_tuple_elements %d 
must be > 0" % number_of_tuple_elements) @@ -293,9 +291,8 @@ class InfeedQueue(object): self.number_of_tuple_elements """ if len(input_tensors) != self.number_of_tuple_elements: - raise ValueError( - "input_tensors is %s, but should be a list of %d Tensors", ( - str(input_tensors), self.number_of_tuple_elements)) + raise ValueError("input_tensors is %s, but should be a list of %d Tensors" + % (str(input_tensors), self.number_of_tuple_elements)) self.set_tuple_shapes([t.shape for t in input_tensors]) self.set_tuple_types([t.dtype for t in input_tensors]) @@ -451,8 +448,8 @@ class InfeedQueue(object): for i in xrange(1, self.number_of_tuple_elements): if devices[0] != devices[i]: raise ValueError( - "input devices for shard %d are %s, but should all be the same", - index, str(devices)) + "input devices for shard %d are %s, but should all be the same" % + (index, str(devices))) with ops.colocate_with(inputs[0]): return tpu_ops.infeed_enqueue_tuple( inputs=inputs, @@ -792,18 +789,14 @@ class _PartitionedInfeedQueue(InfeedQueue): Args: tensor: Input tensor for partitioning. - dims: A list of integer describes how to partition the input tensor. + dims: 1-D np.array of the list of integer describes how to partition the + input tensor. Raises: ValueError: If the tensor can't be partitioned by dims or the num_cores_per_replica doesn't match the number of partitions(dims.prod()). 
""" - if dims is None: - return - - dims = np.array(dims) - if (dims < 1).any(): raise ValueError("All input partition dims must be >= 1.") @@ -823,11 +816,6 @@ class _PartitionedInfeedQueue(InfeedQueue): "partition dims = {}).".format(tensor.shape.as_list(), dims)) tensor.shape.assert_is_fully_defined() - if (np.array(tensor.shape.as_list()) % dims != 0).any(): - raise ValueError( - "All input partition dims must divide exactly into the `Tensor` " - "shape (tensor shape = {}, input partition dims = {}).".format( - tensor.shape.as_list(), dims)) def _partition_or_replicate_on_host(self, tensor, dims): """Partitions or replicates the input tensor. @@ -840,16 +828,39 @@ class _PartitionedInfeedQueue(InfeedQueue): Returns: An iterator of `Tensor`s or a list of partioned tensors. """ - self._check_input_partition_dims(tensor, dims) if dims is None: return itertools.repeat(tensor) - else: - output = [tensor] - for axis, dim in enumerate(dims): - if dim > 1: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output + dims = np.array(dims) + self._check_input_partition_dims(tensor, dims) + output = [tensor] + shape_list = np.array(tensor.shape.as_list()) + quotients, remainders = np.divmod(shape_list, dims) + for axis, (quotient, remainder, dim, original_size) in enumerate( + zip(quotients, remainders, dims, shape_list)): + if dim <= 1: + continue + if remainder > 0: + # For each dimension, when it cannot be evenly partitioned, XLA assumes + # tensors are partitioned in a greedy manner by using + # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims + # are (2, 4). 
Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => + # [[(3, 4), (3, 4), (2, 4), (2, 2)], + # [(2, 4), (2, 4), (2, 4), (2, 2)]] + ceil_ratio = quotient + 1 + num_full_slots, left_over = np.divmod(original_size, ceil_ratio) + num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] + if len(num_or_size_splits) < dim: + num_or_size_splits += [0] * (dim - len(num_or_size_splits)) + new_output = [] + for x in output: + new_output.append( + array_ops.split( + x, num_or_size_splits=num_or_size_splits, axis=axis)) + output = new_output + else: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): """Tags appropriate XLA sharding attribute to the dequeued tensor. @@ -866,13 +877,9 @@ class _PartitionedInfeedQueue(InfeedQueue): elif np.prod(dims) == 1: return xla_sharding.assign_device(tensor, 0) else: - tile_shape = np.array(tensor.shape.as_list()) // dims tile_assignment = np.arange(np.prod(dims)).reshape(dims) return xla_sharding.tile( tensor=tensor, - tile_shape=xla_shape.CreateShapeFromDtypeAndTuple( - dtype=np.dtype(tensor.dtype.as_numpy_dtype), - shape_tuple=tile_shape), tile_assignment=tile_assignment) def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index ec682e5829c4df536a043334b74200f0b6259df3..d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -52,6 +52,7 @@ def _query_tpu_system_metadata(master_address, cluster_def=None, devices = [] device_dict = collections.defaultdict(list) + # TODO(b/120564445): Replace with standard library for retries. 
retry_count = 1 while True: logging.info('Querying Tensorflow master (%s) for TPU system metadata.', diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index b6c350ecd7588221b0e7bc979ed1be3b911c8cfd..0187b4bec6ecc55943bf48b9268a74e18ea5b488 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -166,8 +166,8 @@ def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): # control dependencies from any side-effecting operations. if input_arity == 0: inputs = [array_ops.constant(0)] - return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs, - name="") + return control_flow_ops.while_loop( + condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1) def repeat(n, body, inputs=None, infeed_queue=None, name=None): diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md index b6514e19dc92fe4c7cdcdb6582a7c0ad5ad573d5..552febd80bd35b37a95cdaaf8d5923278311ac8e 100644 --- a/tensorflow/contrib/tpu/tpu_estimator.md +++ b/tensorflow/contrib/tpu/tpu_estimator.md @@ -89,12 +89,9 @@ handle training: dataset = tf.data.TFRecordDataset( filename, buffer_size=FLAGS.dataset_reader_buffer_size) - dataset = dataset.map(parser).cache().repeat().batch(batch_size) - images, labels = dataset.make_one_shot_iterator().get_next() - # set_shape to give inputs statically known shapes. 
- images.set_shape([batch_size, 28 * 28]) - labels.set_shape([batch_size]) - return images, labels + dataset = dataset.map(parser).cache().repeat().batch( + batch_size, drop_remainder=True) + return dataset return input_fn diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 00295f57f60858db5234ce28cc643ea9eee44daa..f6427ae05a20f253edf030eff0f860361616042b 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,7 +26,6 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", - "python/training/tensor_queue_dataset.py", "python/training/training.py", "python/training/tuner.py", ], @@ -287,28 +286,6 @@ py_test( ], ) -py_test( - name = "tensor_queue_dataset_test", - size = "large", - srcs = ["python/training/tensor_queue_dataset_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":training_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base", - "//third_party/py/numpy", - ], -) - tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 3547e71184ec2b99163ea4247c01d24487811b47..87ce57ef060a0eb9383248255713421c14988416 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -59,8 +59,6 @@ from tensorflow.contrib.training.python.training.hparam import * from tensorflow.contrib.training.python.training.resample import * from 
tensorflow.contrib.training.python.training.sampling_ops import * from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import * -from tensorflow.contrib.training.python.training.tensor_queue_dataset import enqueue_in_queue_dataset -from tensorflow.contrib.training.python.training.tensor_queue_dataset import prepend_from_queue_and_padded_batch_dataset from tensorflow.contrib.training.python.training.training import add_gradients_summaries from tensorflow.contrib.training.python.training.training import clip_gradient_norms from tensorflow.contrib.training.python.training.training import clip_gradient_norms_fn @@ -79,7 +77,6 @@ _allowed_symbols = [ 'FeedingQueueRunner', 'get_or_create_eval_step', 'StopAfterNEvalsHook', 'SummaryAtEndHook', 'wait_for_new_checkpoint', 'add_gradients_summaries', 'clip_gradient_norms', 'clip_gradient_norms_fn', 'create_train_op', - 'multiply_gradients', 'enqueue_in_queue_dataset', - 'prepend_from_queue_and_padded_batch_dataset', 'train'] + 'multiply_gradients', 'train'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py deleted file mode 100644 index 8896a95327a4cb609a9a78412afa68b316a3131e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.util import nest as tf_nest - - -class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset): - """A `Dataset` that prepends a queue to another `Dataset`. - - A vector of handles to the queue is returned as the first component of - the associated iterator. This vector can be passed to - `enqueue_in_queue_dataset` to add new elements to the queue. 
- """ - - def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): - """Initialize `PrependFromQueueAndPaddedBatchDataset`.""" - super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset) - if sparse.any_sparse(input_dataset.output_classes): - raise TypeError( - "Batching of padded sparse tensors is not currently supported") - self._input_dataset = input_dataset - self._batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - if padded_shapes is None: - self._padded_shapes = nest.map_structure( - convert.partial_shape_to_tensor, input_dataset.output_shapes) - else: - self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, convert.partial_shape_to_tensor, - padded_shapes) - # pylint: disable=protected-access - padding_values = ( - padding_values if padding_values is not None else - dataset_ops._default_padding(input_dataset)) - self._padding_values = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._padding_value_to_tensor, - padding_values, input_dataset.output_types) - # pylint: enable=protected-access - - def _as_variant_tensor(self): - # pylint: disable=protected-access - return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset( - self._input_dataset._as_variant_tensor(), - batch_size=self._batch_size, - padded_shapes=[ - ops.convert_to_tensor(s, dtype=dtypes.int64) - for s in nest.flatten(self._padded_shapes) - ], - padding_values=nest.flatten(self._padding_values), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - # pylint: enable=protected-access - - @property - def output_classes(self): - return (ops.Tensor, self._input_dataset.output_classes) - - def _as_batch_shape(self, shape_like): - return tensor_shape.vector(None).concatenate( - tensor_util.constant_value_as_shape(shape_like)) - - @property - def output_shapes(self): - # First output is a variant representing the Queue - return 
(tensor_shape.vector(None), - nest.map_structure(self._as_batch_shape, self._padded_shapes)) - - @property - def output_types(self): - # First output is a variant representing the Queue - return (dtypes.variant, self._input_dataset.output_types) - - -def prepend_from_queue_and_padded_batch_dataset(batch_size, - padding_values=None, - padded_shapes=None): - """A transformation that prepends a queue to a `Dataset` and batches results. - - A vector of handles to the queue is returned as the first component of the - associated iterator. This vector can be passed to `enqueue_in_queue_dataset` - to add new elements to the queue. - - Below is an example of how this dataset might be used to split incoming - variable-length sequences into "head" and "rest" parts, where "rest" parts - are re-enqueued back into the dataset. A more realistic example would - perform some calculation on the "head" and modify some components of "rest" - with the result (before re-enqueueing). - - ```python - dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map(lambda count: (count, tf.ones((count,)))) - # Emit a queue we can prepend to, and counts/values as padded batch. 
- dataset = dataset.apply( - tf.contrib.training.prepend_from_queue_and_padded_batch_dataset( - batch_size=10)) - dataset = dataset.prefetch(1) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = tf.squeeze(tf.where(count > 3), axis=1) - bound = tf.minimum(3, tf.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = tf.gather(count - 3, rest_indices) - value_rest = tf.gather(padded_value[:, bound:], rest_indices) - queue_rest = tf.gather(queue, rest_indices) - enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset( - queue_rest, (count_rest, value_rest)) - with tf.control_dependencies([enqueue_rest_op]): - calculation = fn(value_head) - - while True: # Will raise OutOfRange when finished with all pieces. - session.run(calculation) - ``` - - Args: - batch_size: `int64` scalar tensor. The batch size to use when performing - padded batching. - padding_values: (optional) Nested tuple of scalar tensors. If provided, - the structure and dtypes of padding_values should match that of - incoming dataset's `output_types`. - padded_shapes: (optional) Nested tuple of `int64` vector tensors. - If provided, the structure must match that of the incoming dataset's - `output_types`. If not provided, the incoming dataset's `output_shapes` - is used. Any unknown (`None` or `-1`) dimensions in the shapes are - treated as being unique per-batch: for each batch time, an unknown - dimension is replaced with the maximum given value of this dimension - across all tensors for the given component in the batch. - - Returns: - A `Dataset` transformation function, which can be passed to - `tf.data.Dataset.apply`. 
- """ - - def _apply_fn(dataset): - return _PrependFromQueueAndPaddedBatchDataset( - dataset, - batch_size=batch_size, - padding_values=padding_values, - padded_shapes=padded_shapes) - - return _apply_fn - - -def enqueue_in_queue_dataset(queue, components): - """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`. - - The components' dtypes and shapes must be compatible with the `output_shapes` - attribute of the `dataset` created by - `prepend_from_queue_and_padded_batch_dataset`. This operation supports both - non-batched and batched modes. - - For more details, see the example in the docstring for - `prepend_from_queue_and_padded_batch_dataset`. - - Args: - queue: `variant` scalar or vector tensor. - The tensor emitted by the first component of the iterator associated with - `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar, - then the `components` input tensors should not have a prepended batch - dimension. - components: Nested tuple of tensors, each with a leading batch dimension - if `queue` is a vector. The structure, dtypes, and shapes - (excluding batch dimension) must match the nested tuples - `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue - output types and shapes) of the `dataset` emitted by - the original `prepend_from_queue_and_padded_batch_dataset` call. - - Returns: - An `Operation` that enqueues `components` into the dataset(s) associated - with entries of `queue`. 
- """ - return gen_dataset_ops.enqueue_in_queue_dataset( - queue=queue, components=tf_nest.flatten(components)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py deleted file mode 100644 index c1657fec7bbe4a3227c3ea273b72176ac4066c50..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TensorQueueDataset.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd -from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import string_ops -from tensorflow.python.platform import test - - -class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): - - def testNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types) - self.assertAllEqual(([None],) * 2, - [x.as_list() for x in dataset.output_shapes]) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertEqual([0], self.evaluate(value)) - self.assertEqual([1], self.evaluate(value)) - self.assertEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([0, 1], self.evaluate(value)) - self.assertAllEqual([2], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def 
testBatchedWithBiggerPaddingNoEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=2, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value)) - self.assertAllEqual([[2, 0, 0]], self.evaluate(value)) - with self.assertRaisesOpError("End of sequence"): - self.evaluate(value) - - def testBatchedWithBiggerPaddingOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[3])) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertAllEqual([[0, 0, 0]], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[1, 0, 0]], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[-1, 0, 0]], value_2) - value_3 = sess.run(value) - self.assertAllEqual([[1, 0, 0]], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([[2, 0, 0]], value_4) - value_5 = sess.run(value) - self.assertAllEqual([[-2, 0, 0]], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.cached_session() as sess: - self.assertEqual([0], sess.run(value)) - value_1, _ = sess.run([value, enqueue_negative]) - self.assertEqual([1], 
value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertEqual([-1], value_2) - value_3 = sess.run(value) - self.assertEqual([1], value_3) - value_4, _ = sess.run([value, enqueue_negative]) - self.assertEqual([2], value_4) - value_5 = sess.run(value) - self.assertEqual([-2], value_5) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testBatchedOneEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], - array_ops.expand_dims( - value[0], axis=0)) - with self.cached_session() as sess: - value_0, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 1], value_0) - value_1, _ = sess.run([value, enqueue_zeroth]) - self.assertAllEqual([0, -1], value_1) - value_2, _ = sess.run([value, enqueue_negative]) - self.assertAllEqual([0, 2], value_2) - self.assertAllEqual([0, -2], sess.run(value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testManyEnqueue(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue_many_more = [ - tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) - for i in range(1000) - ] - with self.cached_session() as sess: - value_0, _ = sess.run((value, enqueue_many_more)) - self.assertEqual([0], value_0) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i] for i in range(1000)], sorted(rest)) - # Going back to the original input. 
- value_1, _ = sess.run((value, enqueue_many_more)) - self.assertEqual(1, value_1) - rest = [] - for _ in range(1000): - rest.append(sess.run(value)) - self.assertEquals([[100 + i + 1] for i in range(1000)], sorted(rest)) - with self.assertRaisesOpError("End of sequence"): - sess.run(value) - - def testEnqueueWithPrefetch(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - # Prefetching will request additional values before they are - # available to the queue. - dataset = dataset.prefetch(buffer_size=3) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) - with self.cached_session() as sess: - i = 0 - while i < 4: - received, _ = sess.run((value, enqueue)) - if received.size > 0: - self.assertAllEqual([i], received) - i += 1 - received_last = False - while True: - try: - received = sess.run(value) - if received.size > 0: - self.assertAllEqual([4], received) - received_last = True - except errors.OutOfRangeError: - break - self.assertTrue(received_last) - - def testDatasetWithPaddedShapeSmallerThanInputFails(self): - dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=1, padded_shapes=[2])) - iterator = dataset.make_one_shot_iterator() - _, value = iterator.get_next() - with self.cached_session() as sess: - with self.assertRaisesOpError( - r"Incompatible input shapes at component 0 between " - r"input dataset this dataset: \[3\] vs. 
\[2\]"): - sess.run(value) - - def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - iterator = dataset.make_one_shot_iterator() - queue_handle, value = iterator.get_next() - - enqueue_bad_structure = tqd.enqueue_in_queue_dataset( - queue_handle, (value, value)) - enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [1.0], - dtype=np.float32)) - enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset( - queue_handle, ([1],)) - enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle, - np.array( - [[1]], dtype=np.int32)) - - with self.cached_session() as sess: - with self.assertRaisesOpError( - "mismatched number of tensors. Queue expects 1 tensors but " - "tried to insert 2"): - sess.run(enqueue_bad_structure) - with self.assertRaisesOpError(r"Expected component 0 to have batched " - r"shape \[1,...\], but saw shape: \[\]"): - sess.run(enqueue_bad_shape_no_batch_dim) - with self.assertRaisesOpError( - r"mismatched shapes at component 0. Attempted to insert tensor " - r"with shape \[1\] but queue expected shape: \[\]"): - sess.run(enqueue_bad_shape) - with self.assertRaisesOpError( - r"mismatched dtypes at component 0. 
Attempted to insert tensor " - r"of type float but queue expected type: int32"): - sess.run(enqueue_bad_dtype) - - def testEnqueueWithPaddedBatchFailsWithInformativeError(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2]) - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1)) - with self.assertRaisesRegexp( - TypeError, r"Unable to create padding for field of type 'variant'"): - dataset.padded_batch(batch_size=10, padded_shapes=[1]) - - def testOneEnqueueWithPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3)) - - iterator = dataset.make_one_shot_iterator() - queue, (count, padded_value) = iterator.get_next() - - # Split the padded_value into two pieces: head and rest - rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1) - bound = math_ops.minimum(2, math_ops.reduce_max(count)) - value_head = padded_value[:, :bound] - count_rest = array_ops.gather(count - 2, rest_indices) - value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:] - queue_rest = array_ops.gather(queue, rest_indices) - enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest, - (count_rest, value_rest)) - with ops.control_dependencies([enqueue_rest_op]): - calc = array_ops.identity(value_head) - - with self.cached_session() as sess: - self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) - self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - self.assertAllEqual([[6, 6]], sess.run(calc)) - # Get some final batches due to prefetching. 
- for _ in range(3): - try: - self.assertAllEqual( - np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc)) - except errors.OutOfRangeError as e: - self.assertTrue(str(e).startswith("End of sequence")) - - def testNonstandardPadding(self): - dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6]) - # Make a dataset of variable-length vectors and their lengths. - dataset = dataset.map( - lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype))) - # Emit a queue we can prepend to, and counts/values as padded - # batch. - dataset = dataset.apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=3, padding_values=( - 0, - -1, - ))) - - iterator = dataset.make_one_shot_iterator() - _, (unused_count, padded_value) = iterator.get_next() - - with self.cached_session() as sess: - self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], - sess.run(padded_value)) - self.assertAllEqual([[6] * 6], sess.run(padded_value)) - with self.assertRaisesOpError("End of sequence"): - sess.run(padded_value) - - -# TODO(ebrevdo): Figure out how to use run_core_tests to test state -# saving of an iterator that's had some tensors enqueued into its queue. 
-class PrependFromQueueAndPaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPrependFromQueueAndPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).apply( - tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4)) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).apply( - tqd.prepend_from_queue_and_padded_batch_dataset( - batch_size=4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, ""))) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index f7c979e86320d59ad033e2b8d7fcdff89ce0d133..9db80f6b5736d849d88e1e41ea467a5ff11844f5 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -1028,7 +1027,10 @@ Status RdmaTensorResponse::PrepareRecvTensor( return errors::Aborted( "RecvTensor expects a different device incarnation: ", parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), - ". Your worker job was probably restarted. Check your " + ". Your worker job (\"", + channel_->adapter_->worker_env_->session_mgr->LegacySession() + ->worker_name, + "\") was probably restarted. Check your " "worker job for the reason why it was restarted."); } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a701b38d4b3e736a72f20084dbaa6489f1232fb0..66714235b535c14a8f13c40bb2a4df8d7494dc05 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -95,7 +95,8 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") -load("//tensorflow:tensorflow.bzl", "if_not_tx2_llvm_or_windows_cuda") +load("//tensorflow:tensorflow.bzl", "if_nccl") +load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps") load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config @@ -112,6 +113,7 @@ load( "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", + "tf_additional_logger_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", @@ -300,6 +302,7 @@ filegroup( "platform/env_time.h", "platform/logging.h", "platform/macros.h", + 
"platform/platform_strings.h", "platform/types.h", ], visibility = ["//visibility:private"], @@ -442,6 +445,18 @@ cc_library( ] + tf_additional_human_readable_json_deps(), ) +cc_library( + name = "logger", + srcs = tf_platform_srcs(["logger.cc"]), + hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]), + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_logger_deps(), +) + filegroup( name = "platform_env_hdrs", srcs = [ @@ -477,7 +492,10 @@ cc_library( ":platform_env_internal_hdrs", ], copts = tf_copts(), - visibility = ["//tensorflow/core:__subpackages__"], + visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/core:__subpackages__", + ], deps = [ ":error_codes_proto_cc", ":lib", @@ -519,6 +537,19 @@ cc_library( ], ) +cc_library( + name = "platform_strings", + srcs = tf_platform_srcs([ + "platform/platform_strings.cc", + "platform/platform_strings_computed.h", + ]), + hdrs = [ + "platform/platform_strings.h", + ], + visibility = ["//tensorflow/core:__subpackages__"], + deps = [":lib"], +) + filegroup( name = "platform_other_hdrs", srcs = [ @@ -841,6 +872,7 @@ tf_cuda_library( "framework/dataset_stateful_op_whitelist.h", "framework/device_base.h", "framework/function.h", + "framework/function_handle_cache.h", "framework/graph_def_util.h", "framework/graph_to_functiondef.h", "framework/kernel_def_builder.h", @@ -884,6 +916,7 @@ tf_cuda_library( "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", + "util/dump_graph.h", "util/events_writer.h", "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", @@ -901,6 +934,7 @@ tf_cuda_library( "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", + "util/tensor_ops_util.h", "util/tensor_slice_reader.h", "util/tensor_slice_reader_cache.h", "util/tensor_slice_writer.h", @@ -1038,6 +1072,7 @@ tf_gen_op_libs( "batch_ops", "bitwise_ops", "boosted_trees_ops", + "tensor_forest_ops", 
"candidate_sampling_ops", "checkpoint_ops", "collective_ops", @@ -1085,7 +1120,11 @@ tf_gen_op_libs( op_lib_names = [ "string_ops", ], - deps = ["@com_google_absl//absl/strings"], + deps = [ + ":lib_internal", + ":lib_proto_parsing", + "@com_google_absl//absl/strings", + ], ) tf_gen_op_libs( @@ -1187,6 +1226,7 @@ cc_library( ":batch_ops_op_lib", ":bitwise_ops_op_lib", ":boosted_trees_ops_op_lib", + ":tensor_forest_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", ":collective_ops_op_lib", @@ -1340,6 +1380,7 @@ cc_library( "//tensorflow/core/kernels:batch_kernels", "//tensorflow/core/kernels:bincount_op", "//tensorflow/core/kernels:boosted_trees_ops", + "//tensorflow/core/kernels:tensor_forest_ops", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", "//tensorflow/core/kernels:collective_ops", @@ -1386,9 +1427,7 @@ cc_library( "//tensorflow/core/kernels:summary_kernels", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", - ] + tf_additional_cloud_kernel_deps() + if_not_tx2_llvm_or_windows_cuda([ - "//tensorflow/core/kernels:nccl_kernels", - ]) + if_not_windows([ + ] + tf_additional_cloud_kernel_deps() + if_not_windows([ "//tensorflow/core/kernels:fact_op", "//tensorflow/core/kernels:array_not_windows", "//tensorflow/core/kernels:math_not_windows", @@ -1413,6 +1452,8 @@ cc_library( ]) + if_cuda([ "//tensorflow/core/grappler/optimizers:gpu_swapping_kernels", "//tensorflow/core/grappler/optimizers:gpu_swapping_ops", + ]) + if_nccl([ + "//tensorflow/core/kernels:nccl_kernels", ]), ) @@ -1437,7 +1478,7 @@ tf_cuda_library( ":gpu_runtime", ":lib", ":ops", - ], + ] + tensorflow_opensource_extra_deps(), ) cc_library( @@ -1577,6 +1618,8 @@ filegroup( "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", + "platform/**/logger.cc", + "platform/**/logger.h", "platform/default/test_benchmark.*", "platform/cuda.h", "platform/google/**/*", @@ -1639,6 
+1682,9 @@ filegroup( # operators, use :android_tensorflow_lib if you want full operator # support. # +# If you just need TensorFlow types, e.g. Tensors, use +# :android_tensorflow_lib_lite_no_runtime. +# # Compiles to a trivial library on non-Android to prevent irrelevant # build errors. If not building this as part of an android_binary, # a command such as the following must be used: @@ -1649,7 +1695,33 @@ filegroup( cc_library( name = "android_tensorflow_lib_lite", srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":mobile_additional_lib_deps", + ":protos_all_cc_impl", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@protobuf_archive//:protobuf", + ], + alwayslink = 1, +) + +cc_library( + name = "android_tensorflow_lib_lite_nortti", + srcs = if_android(["//tensorflow/core:android_srcs"]), + copts = tf_copts(android_optimization_level_override = None) + [ + "-DSUPPORT_SELECTIVE_REGISTRATION", + ] + tf_opts_nortti_if_android(), linkopts = ["-lz"], tags = [ "manual", @@ -1671,8 +1743,8 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", ], ) @@ -1761,50 +1833,21 @@ cc_library( # Does not contain operators. In contrast to android_tensorflow_lib_lite, # this links in framework support for all types, relying on selective # registration of ops to prune code size. -cc_library( +# +# TODO(gonnet): Move all users of these aliases to the corresponding +# :android_tensorflow_lib_lite* targets and remove. 
+alias( name = "android_tensorflow_lib_selective_registration", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], + actual = ":android_tensorflow_lib_lite", visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, ) # Android library for use with the SELECTIVE_REGISTRATION feature with # no proto_rtti. -cc_library( +alias( name = "android_tensorflow_lib_selective_registration_nortti", - srcs = if_android(["//tensorflow/core:android_srcs"]), - copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ - "-DSUPPORT_SELECTIVE_REGISTRATION", - ], - linkopts = if_android(["-lz"]), - tags = [ - "manual", - "notap", - ], + actual = ":android_tensorflow_lib_lite_nortti", visibility = ["//visibility:public"], - deps = [ - ":protos_all_cc_impl", - "//third_party/eigen3", - "@double_conversion//:double-conversion", - "@nsync//:nsync_cpp", - "@protobuf_archive//:protobuf", - ], - alwayslink = 1, ) filegroup( @@ -2045,9 +2088,7 @@ tf_proto_library_cc( srcs = ["protobuf/master.proto"], cc_api_version = 2, protodeps = tf_additional_all_protos(), - visibility = [ - "//tensorflow:internal", - ], + visibility = ["//tensorflow:internal"], ) tf_proto_library_cc( @@ -2187,6 +2228,7 @@ cc_library( "platform/**/env_time.cc", "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", + "platform/**/logger.cc", "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", @@ -2199,6 +2241,7 @@ cc_library( "platform/**/stream_executor.h", "platform/**/env_time.cc", "platform/**/device_tracer.cc", + "platform/**/logger.cc", "platform/**/logging.cc", 
"platform/**/human_readable_json.cc", "platform/abi.cc", @@ -2641,6 +2684,8 @@ tf_cuda_library( ":stats_calculator_portable", ":version_lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", "//third_party/eigen3", @@ -2943,6 +2988,7 @@ tf_cuda_library( ":lib_internal", ":proto_text", ":protos_all_cc", + "@com_google_absl//absl/memory", "//third_party/eigen3", "//tensorflow/core/grappler:grappler_item", ] + mkl_deps(), @@ -3008,7 +3054,6 @@ tf_cuda_library( hdrs = ["common_runtime/metrics.h"], deps = [ ":lib", - "@com_google_absl//absl/time", ], ) @@ -3033,7 +3078,6 @@ tf_cuda_library( ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", "//tensorflow/core/kernels:function_ops", - "@com_google_absl//absl/time", ], alwayslink = 1, ) @@ -3393,6 +3437,7 @@ tf_cc_tests( "platform/profile_utils/cpu_utils_test.cc", "platform/stacktrace_handler_test.cc", "platform/subprocess_test.cc", + "platform/vmodule_benchmark_test.cc", ], deps = [ ":lib", @@ -3406,6 +3451,20 @@ tf_cc_tests( ], ) +tf_cc_test( + name = "vmodule_test", + srcs = ["platform/vmodule_test.cc"], + tags = ["optonly"], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "lib_random_random_distributions_test", srcs = ["lib/random/random_distributions_test.cc"], @@ -3421,6 +3480,16 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_strings_test", + size = "small", + srcs = ["platform/platform_strings_test.cc"], + deps = [ + ":lib", + ":platform_strings", + ], +) + tf_cc_test( name = "platform_env_test", size = "small", @@ -3668,6 +3737,7 @@ tf_cc_tests( "util/bcast_test.cc", "util/command_line_flags_test.cc", "util/device_name_utils_test.cc", + "util/dump_graph_test.cc", "util/equal_graph_def_test.cc", "util/events_writer_test.cc", 
"util/example_proto_fast_parsing_test.cc", @@ -3798,6 +3868,7 @@ tf_cc_tests_gpu( ":test", ":test_main", ":testlib", + "@com_google_absl//absl/memory", ], ) @@ -3826,6 +3897,7 @@ tf_cc_tests_gpu( ":test", ":test_main", ":testlib", + "@com_google_absl//absl/memory", ], ) @@ -4099,6 +4171,7 @@ tf_cc_test( "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:immutable_constant_op", "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:topk_op", "//third_party/eigen3", ], ) @@ -4392,6 +4465,7 @@ tf_cc_test( "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:shape_ops", "//third_party/eigen3", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -4871,6 +4945,7 @@ transitive_hdrs( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:platform_strings", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor", ], diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 6f9885691595368ab50cfe660b1b5c75673063cf..7405e2ace72d1c08cf87cc0040e617379e18149b 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { namespace { @@ -182,11 +181,14 @@ void TestDeprecationVersionSetCorrectly( for (const auto& name_and_api_def : api_defs_map) { const auto& name = name_and_api_def.first; const auto& api_def = name_and_api_def.second; - ASSERT_TRUE(api_def.deprecation_version() == 0 || - api_def.deprecation_message().empty()) - << "ApiDef that includes deprecation_version > 0 must also specify " - << "a deprecation_message. 
Op " << name - << " has deprecation_version > 0 but deprecation_message is not set."; + if (api_def.deprecation_version() != 0) { + ASSERT_TRUE(api_def.deprecation_version() > 0) + << "Found ApiDef with negative deprecation_version"; + ASSERT_FALSE(api_def.deprecation_message().empty()) + << "ApiDef that includes deprecation_version > 0 must also specify " + << "a deprecation_message. Op " << name + << " has deprecation_version > 0 but deprecation_message is not set."; + } } } } // namespace diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt index 639d962874d083472e6df13550e107026fd2d0a1..32def912f83e420eab58a3071f573ae81139a298 100644 --- a/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BatchDataset.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "BatchDataset" + visibility: HIDDEN in_arg { name: "batch_size" description: <