diff --git a/WORKSPACE b/WORKSPACE index 957b8d8528dc9b5e2ea134921b28601aa6fed2d1..9f07b9fd47136d058cc4039ed6948db539485039 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -29,7 +29,7 @@ load( bazel_toolchains_repositories() load( - "@io_bazel_rules_docker//container:container.bzl", + "@io_bazel_rules_docker//repositories:repositories.bzl", container_repositories = "repositories", ) @@ -43,29 +43,17 @@ remote_config_workspace() # Apple and Swift rules. http_archive( name = "build_bazel_rules_apple", - sha256 = "4fe4ee824200b48821730f89ff260984332dc3551db587c24691235d1d96a8a7", - strip_prefix = "rules_apple-0.10.0", - urls = ["https://github.com/bazelbuild/rules_apple/archive/0.10.0.tar.gz"], -) -http_archive( - name = "build_bazel_rules_swift", - sha256 = "6544ff5615febec0342de1127144d2f3e43ea80fb7f9b1ade65e6a184e39e618", - strip_prefix = "rules_swift-0.5.0", - urls = ["https://github.com/bazelbuild/rules_swift/archive/0.5.0.tar.gz"], -) -http_archive( - name = "bazel_skylib", - sha256 = "eb5c57e4c12e68c0c20bc774bfbc60a568e800d025557bc4ea022c6479acc867", - strip_prefix = "bazel-skylib-0.6.0", - urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.6.0.tar.gz"], + sha256 = "73b4980a318d203d3307f850e27e66ec5cc8d223147a3475a6f11597eb6438a5", + strip_prefix = "rules_apple-0.13.0", + urls = ["https://github.com/bazelbuild/rules_apple/archive/0.13.0.tar.gz"], ) http_file( name = "xctestrunner", executable = 1, - urls = ["https://github.com/google/xctestrunner/releases/download/0.2.5/ios_test_runner.par"], + urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"], ) load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") -apple_rules_dependencies(ignore_version_differences = True) +apple_rules_dependencies() load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") swift_rules_dependencies() @@ -134,4 +122,3 @@ http_archive( "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) - diff --git a/configure.py b/configure.py index 8dcd31822000820df12c7e96f5c57c68ed605f41..3eb09a1ae905b70dc5d02fab7c316f73c79633dd 100644 --- a/configure.py +++ b/configure.py @@ -55,6 +55,12 @@ NCCL_LIB_PATHS = [ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', '' ] +# List of files to be configured for using Bazel on Apple platforms. +APPLE_BAZEL_FILES = [ + 'tensorflow/lite/experimental/objc/BUILD', + 'tensorflow/lite/experimental/swift/BUILD' +] + if platform.machine() == 'ppc64le': _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/' else: @@ -256,6 +262,7 @@ def reset_tf_configure_bazelrc(): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() + def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -785,8 +792,7 @@ def set_gcc_host_compiler_path(environ_cp): environ_cp, var_name='GCC_HOST_COMPILER_PATH', var_default=default_gcc_host_compiler_path, - ask_for_var= - 'Please specify which gcc should be used by nvcc as the host compiler.', + ask_for_var='Please specify which gcc should be used by nvcc as the host compiler.', check_success=os.path.exists, error_msg='Invalid gcc path. %s cannot be found.', ) @@ -1237,6 +1243,7 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1273,13 +1280,15 @@ def set_tf_cuda_compute_capabilities(environ_cp): ask_cuda_compute_capabilities = ( 'Please specify a list of comma-separated ' - 'Cuda compute capabilities you want to ' + 'CUDA compute capabilities you want to ' 'build with.\nYou can find the compute ' 'capability of your device at: ' 'https://developer.nvidia.com/cuda-gpus.\nPlease' ' note that each additional compute ' 'capability significantly increases your ' - 'build time and binary size. [Default is: %s]: ' % + 'build time and binary size, and that ' + 'TensorFlow only supports compute ' + 'capabilities >= 3.5 [Default is: %s]: ' % default_cuda_compute_capabilities) tf_cuda_compute_capabilities = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', @@ -1292,12 +1301,14 @@ def set_tf_cuda_compute_capabilities(environ_cp): for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: - print('Invalid compute capability: ' % compute_capability) + print('Invalid compute capability: %s' % compute_capability) all_valid = False else: - ver = int(m.group(0).split('.')[0]) - if ver < 3: - print('Only compute capabilities 3.0 or higher are supported.') + ver = float(m.group(0)) + if ver < 3.5: + print('ERROR: TensorFlow only supports CUDA compute capabilities 3.5 ' + 'and higher. Please re-specify the list of compute ' + 'capabilities excluding version %s.' % ver) all_valid = False if all_valid: @@ -1484,6 +1495,34 @@ def set_other_mpi_vars(environ_cp): 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % (mpi_home, mpi_home, mpi_home)) +def system_specific_test_config(env): + """Add default test flags required for TF tests to bazelrc.""" + write_to_bazelrc('test --flaky_test_attempts=3') + write_to_bazelrc('test --test_size_filters=small,medium') + write_to_bazelrc( + 'test --test_tag_filters=-benchmark-test,-no_oss,-oss_serial') + write_to_bazelrc('test --build_tag_filters=-benchmark-test,-no_oss') + if is_windows(): + if env.get('TF_NEED_CUDA', None) == 1: + write_to_bazelrc( + 'test --test_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') + write_to_bazelrc( + 'test --build_tag_filters=-no_windows,-no_windows_gpu,-no_gpu') + else: + write_to_bazelrc('test --test_tag_filters=-no_windows,-gpu') + write_to_bazelrc('test --build_tag_filters=-no_windows,-gpu') + elif is_macos(): + write_to_bazelrc('test --test_tag_filters=-gpu,-nomac,-no_mac') + write_to_bazelrc('test --build_tag_filters=-gpu,-nomac,-no_mac') + elif is_linux(): + if env.get('TF_NEED_CUDA', None) == 1: + write_to_bazelrc('test --test_tag_filters=-no_gpu') + write_to_bazelrc('test --build_tag_filters=-no_gpu') + write_to_bazelrc('test --test_env=LD_LIBRARY_PATH') + else: + write_to_bazelrc('test --test_tag_filters=-gpu') + write_to_bazelrc('test --build_tag_filters=-gpu') + def set_system_libs_flag(environ_cp): syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') @@ -1513,10 +1552,6 @@ def set_windows_build_flags(environ_cp): # The host and target platforms are the same in Windows build. So we don't # have to distinct them. This avoids building the same targets twice. write_to_bazelrc('build --distinct_host_configuration=false') - # Enable short object file path to avoid long path issue on Windows. - # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 - # Short object file path will be enabled by default. - write_to_bazelrc('build --experimental_shortened_obj_file_path=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', @@ -1537,6 +1572,23 @@ def config_info_line(name, help_text): print('\t--config=%-12s\t# %s' % (name, help_text)) +def configure_apple_bazel_rules(): + """Configures Bazel rules for building on Apple platforms. + + Enables analyzing and building Apple Bazel rules on Apple platforms. This + function will only be executed if `is_macos()` is true. + """ + if not is_macos(): + return + for filepath in APPLE_BAZEL_FILES: + print( + 'Configuring %s file to analyze and build Bazel rules on Apple platforms.' + % filepath) + existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') + renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) + os.rename(existing_filepath, renamed_filepath) + + def main(): global _TF_WORKSPACE_ROOT global _TF_BAZELRC @@ -1556,7 +1608,7 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.19.0', '0.21.0') + check_bazel_version('0.19.0', '0.22.0') reset_tf_configure_bazelrc() @@ -1577,6 +1629,8 @@ def main(): if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' + else: + environ_cp['TF_CONFIGURE_APPLE_BAZEL_RULES'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1679,6 +1733,16 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) + system_specific_test_config(os.environ) + + if get_var( + environ_cp, 'TF_CONFIGURE_APPLE_BAZEL_RULES', + 'Configure Bazel rules for Apple platforms', False, + ('Would you like to configure Bazel rules for building on Apple platforms?' + ), 'Configuring Bazel rules for Apple platforms.', + 'Not configuring Bazel rules for Apple platforms.'): + configure_apple_bazel_rules() + print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See .bazelrc for more ' 'details.') @@ -1687,8 +1751,9 @@ def main(): config_info_line('gdr', 'Build with GDR support.') config_info_line('verbs', 'Build with libverbs support.') config_info_line('ngraph', 'Build with Intel nGraph support.') - config_info_line('dynamic_kernels', - '(Experimental) Build kernels into separate shared objects.') + config_info_line( + 'dynamic_kernels', + '(Experimental) Build kernels into separate shared objects.') print('Preconfigured Bazel build configs to DISABLE default on features:') config_info_line('noaws', 'Disable AWS S3 filesystem support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6bc8403d126a58c1eb6499ab7f224e12c6bc5aa4..f53982f1efc9885cc12dcc672ad819c762aca378 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -94,6 +94,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "emscripten", + values = {"crosstool_top": "//external:android/emscripten"}, + visibility = ["//visibility:public"], +) + config_setting( name = "raspberry_pi_armeabi", values = { @@ -456,8 +462,7 @@ tf_cc_shared_object( "//tensorflow:darwin": [], "//tensorflow:windows": [], "//conditions:default": [ - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_framework_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)", ], }), linkstatic = 1, @@ -491,15 +496,13 @@ tf_cc_shared_object( name = "libtensorflow.so", linkopts = select({ "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow/c:exported_symbols.lds)", + "-Wl,-exported_symbols_list,$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow/c:version_script.lds)", + "-Wl,--version-script,$(location //tensorflow/c:version_script.lds)", ], }), visibility = ["//visibility:public"], @@ -517,14 +520,12 @@ tf_cc_shared_object( name = "libtensorflow_cc.so", linkopts = select({ "//tensorflow:darwin": [ - "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "$(location //tensorflow:tf_exported_symbols.lds)", + "-Wl,-exported_symbols_list,$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//conditions:default": [ "-z defs", - "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "$(location //tensorflow:tf_version_script.lds)", + "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], }), visibility = ["//visibility:public"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index a93799bfe84b0f9c4743e1ad0effd6e69ad7f3f2..ddcacfcbe2d4d8b089f10f1a771384dc8c4fd199 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -26,14 +26,28 @@ import sys as _sys # API IMPORTS PLACEHOLDER +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +# We're using bitwise, but there's nothing special about that. +_API_MODULE = bitwise # pylint: disable=undefined-variable +_current_module = _sys.modules[__name__] +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) +if not hasattr(_current_module, '__path__'): + __path__ = [_tf_api_dir] +elif _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + # pylint: disable=g-bad-import-order from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg="Limited tf.summary API due to missing TensorBoard installation") _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( 'tensorflow_estimator.python.estimator.api._v2.estimator')) -_current_module = _sys.modules[__name__] if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, @@ -42,14 +56,6 @@ if not hasattr(_current_module, 'estimator'): _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=('tensorflow.python.keras.api._v2.keras')) -# Make sure directory containing top level submodules is in -# the __path__ so that "from tensorflow.foo import bar" works. -# We're using bitwise, but there's nothing special about that. -_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable -if not hasattr(_current_module, '__path__'): - __path__ = [_tf_api_dir] -elif _tf_api_dir not in __path__: - __path__.append(_tf_api_dir) # Enable TF2 behaviors from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top @@ -111,5 +117,10 @@ try: except NameError: pass +# Add module aliases +if hasattr(_current_module, 'keras'): + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index eeca8f0d566a6401cb64e4fe3f0ee3c5aeb4ece2..5eb25a81b7f765f551bc4f1b7ba99b35dbc6b7bb 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -70,7 +70,7 @@ _API_MODULE = app # pylint: disable=undefined-variable # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) # pylint: disable=undefined-variable +_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) if not hasattr(_current_module, '__path__'): __path__ = [_tf_api_dir] elif _tf_api_dir not in __path__: diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ef52a28460062b57317b4027ab83479e5e075b5f..ef7863dc0d5cbd57da30baa6e04278c2a0354b25 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -67,6 +67,23 @@ tf_cuda_library( tf_cuda_library( name = "c_api", + hdrs = ["c_api.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api_no_xla", + ":c_api_internal", + ] + select({ + "//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], + }), +) + +tf_cuda_library( + name = "c_api_no_xla", srcs = [ "c_api.cc", "c_api_function.cc", @@ -75,14 +92,12 @@ tf_cuda_library( "c_api.h", ], copts = tf_copts(), - visibility = ["//visibility:public"], - deps = select({ + visibility = ["//tensorflow/c:__subpackages__"], + deps = [":c_api_internal"] + select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api_internal", "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/cc:gradients", "//tensorflow/cc:ops", @@ -97,13 +112,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/kernels:logging_ops", ], - }) + select({ - "//tensorflow:with_xla_support": [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/jit", - ], - "//conditions:default": [], }), ) @@ -156,8 +166,8 @@ tf_cuda_library( hdrs = ["tf_status_helper.h"], visibility = ["//visibility:public"], deps = [ - ":c_api", ":c_api_internal", + ":c_api_no_xla", "//tensorflow/core:lib", ], ) @@ -213,13 +223,13 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - ":c_api", + ":c_api_no_xla", ":c_api_internal", ":tf_status_helper", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":c_api", + ":c_api_no_xla", ":c_api_internal", ":tf_status_helper", "//tensorflow/core:framework", @@ -346,6 +356,7 @@ tf_cc_test( srcs = ["c_api_function_test.cc"], deps = [ ":c_api", + ":c_api_internal", ":c_test_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 94d9f4a6fa2f14cb3343bdd51b7e4d61944444d0..245d7ba2b186895532953aa61ebfc3fc6bf635a7 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/kernels/logging_ops.h" #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -640,7 +641,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } -Status MessageToBuffer(const tensorflow::protobuf::Message& in, +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out) { if (out->data != nullptr) { return InvalidArgument("Passing non-empty TF_Buffer is invalid."); @@ -1310,6 +1311,13 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, reinterpret_cast(values), num_values)); } +void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name, + const char* placeholder) { + tensorflow::AttrValue attr_value; + attr_value.set_placeholder(placeholder); + desc->node_builder.Attr(attr_name, attr_value); +} + void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; @@ -2954,4 +2962,11 @@ void TF_DeleteServer(TF_Server* server) { delete server; #endif } + +void TF_RegisterLogListener(void (*listener)(const char*)) { +#ifndef __ANDROID__ + tensorflow::logging::RegisterListener(listener); +#endif +} + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 8031928dac4de2391f0aec46e69d61a137606e4d..051de3a7dc0f8c630b6c81d2cfa960e5279c93c0 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -549,6 +549,10 @@ TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, const TF_DataType* values, int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, + const char* attr_name, + const char* placeholder); + // Set a 'func' attribute to the specified name. // `value` must point to a string of length `length` bytes. TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, @@ -1310,6 +1314,28 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. +// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + // Returns the name of the graph function. // The return value points to memory that is only usable until the next // mutation to *func. @@ -1743,6 +1769,14 @@ TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); // it will be stopped and joined. TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); +// Register a listener method that processes printed messages. +// +// If any listeners are registered, the print operator will call all listeners +// with the printed messages and immediately return without writing to the +// logs. +TF_CAPI_EXPORT extern void TF_RegisterLogListener( + void (*listener)(const char*)); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index a8325ce494c4f57fcd7e64b2d233ee4e6666bc4e..7ff4084decc686b067226ecaecf2af29d51d42f2 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -9064,11 +9064,6 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++); auto* desc = TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()); - for (auto* input : op->operation.Inputs()) { - auto symbolic_input = getOrCreateSymbolicTensor(trace_ctx, input, status); - if (!status->status.ok()) return nullptr; - TF_AddInput(desc, symbolic_input); - } VLOG(1) << "Adding attrs."; tensorflow::AttrValueMap attrs; @@ -9077,6 +9072,34 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx, desc->node_builder.Attr(attr.first, attr.second); } + VLOG(1) << "Adding inputs."; + const auto& inputs = op->operation.Inputs(); + size_t inputIndex = 0; + const tensorflow::OpDef& op_def = desc->node_builder.op_def(); + for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) { + // TODO(bgogul): Add support for number attributes. + DCHECK(input_arg.number_attr().empty()) + << "Number attributes is not implemented yet."; + if (input_arg.type_list_attr().empty()) { + auto symbolic_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + TF_AddInput(desc, symbolic_input); + continue; + } + const std::string& type_list_attr = input_arg.type_list_attr(); + const auto& attr_value = attrs[type_list_attr]; + DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList) + << "Type list attribute should be a list!"; + std::vector list_inputs(attr_value.list().type_size()); + for (TF_Output& list_input : list_inputs) { + list_input = + getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status); + if (!status->status.ok()) return nullptr; + } + TF_AddInputList(desc, list_inputs.data(), list_inputs.size()); + } + auto* graph_op = TF_FinishOperation(desc, status); if (!status->status.ok()) return nullptr; diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 354ee5f49f373edbc10e7706aa8776f3cc2a17cd..c54021a7517ebbdd00405cbfa9cee8f3f6616cca 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -446,5 +446,29 @@ TEST_F(AddEagerOpToGraphTest, ListAttributesArePreserved) { TFE_DeleteOp(squeeze); } +TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) { + TFE_TensorHandle* scalar = TestScalarTensorHandle(); + TFE_Op* identityn = TFE_NewOp(eager_ctx_, "IdentityN", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + constexpr size_t kNumInputs = 3; + for (size_t i = 0; i < kNumInputs; ++i) { + TFE_OpAddInput(identityn, scalar, status_); + } + TF_DataType types[kNumInputs] = {TF_FLOAT, TF_FLOAT, TF_FLOAT}; + TFE_OpSetAttrTypeList(identityn, "T", types, kNumInputs); + AddEagerOpToGraphAndCheck( + identityn, [this, kNumInputs](TF_Operation* graph_op) { + EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs); + EXPECT_EQ(TF_OperationInputListLength(graph_op, "input", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + EXPECT_EQ(TF_OperationOutputListLength(graph_op, "output", status_), + kNumInputs); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + }); + TFE_DeleteTensorHandle(scalar); + TFE_DeleteOp(identityn); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 28b9f8df9c873ee394eb6a241dd9ac06ba6c8796..03d65ecefd4a9ba5a23a94ed902dfba6dd4fbda9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -162,6 +162,11 @@ Status FillFunctionBody( const std::vector& body_nodes, const std::unordered_map& tensor_renaming, FunctionDef* fdef) { + std::unordered_set func_attr_names; + for (const auto& func_attr : fdef->signature().attr()) { + func_attr_names.insert(func_attr.name()); + } + std::vector in_edges; std::vector control_edges; for (const Node* node : body_nodes) { @@ -243,6 +248,48 @@ Status FillFunctionBody( if (node->op_def().is_stateful()) { fdef->mutable_signature()->set_is_stateful(true); } + + // If this node has any attributes with placeholder value, add the + // attribute to FunctionDef signature. + for (const auto& iter : node->attrs()) { + if (iter.second.placeholder().empty()) { + continue; + } + + // If we already added the attribute, skip it. + string func_attr_name = iter.second.placeholder(); + if (func_attr_names.find(func_attr_name) != func_attr_names.end()) { + continue; + } + + // This node's attribute is a placeholder value, so it does not have type + // information. We check node's OpDef for attribute type. + string node_attr_name = iter.first; + const OpDef::AttrDef* node_attr_def = nullptr; + for (const auto& node_attr : node->op_def().attr()) { + if (node_attr.name() == node_attr_name) { + node_attr_def = &node_attr; + } + } + if (!node_attr_def) { +#ifdef TENSORFLOW_LITE_PROTOS + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. " + "Attribute: ", + node_attr_name); +#else + return errors::Unimplemented( + "Placeholder value is not supported for attributes not in OpDef. " + "Attribute: ", + node_attr_name, ", OpDef: ", node->op_def().DebugString()); +#endif + } + OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr(); + attr_def->set_name(func_attr_name); + attr_def->set_type(node_attr_def->type()); + + func_attr_names.insert(func_attr_name); + } } return Status::OK(); } @@ -255,6 +302,8 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, const std::vector& inputs, const std::vector& outputs, const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, const char* description, FunctionDef* fdef) { if (!output_names.empty()) { DCHECK_EQ(output_names.size(), outputs.size()); @@ -378,6 +427,29 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, fdef->mutable_signature()->set_name(fn_name); } + if (!control_output_names.empty() && + (control_outputs.size() != control_output_names.size())) { + return InvalidArgument( + "Expected number of control outputs (", control_outputs.size(), + ") and the number of control output names (", + control_output_names.size(), ") to match but they do not."); + } + std::unordered_set control_output_names_set; + for (int i = 0; i < control_outputs.size(); ++i) { + string signature_name; + if (!control_output_names.empty()) { + signature_name = control_output_names[i]; + } else { + signature_name = control_outputs[i]->name(); + } + if (!control_output_names_set.insert(signature_name).second) { + return errors::InvalidArgument("Repeated control output name: ", + signature_name); + } + fdef->mutable_signature()->add_control_output(signature_name); + (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name(); + } + return Status::OK(); } @@ -485,14 +557,14 @@ Status ComputeBodyNodes( using tensorflow::Node; using tensorflow::string; -TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, - unsigned char append_hash_to_fn_name, - int num_opers, const TF_Operation* const* opers, - int ninputs, const TF_Output* inputs, - int noutputs, const TF_Output* outputs, - const char* const* output_names, - const TF_FunctionOptions* opts, - const char* description, TF_Status* status) { +TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { tensorflow::mutex_lock l(*const_cast(&fn_body->mu)); // Process inputs. @@ -517,19 +589,34 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, } } + // Process control output names. + std::vector control_output_names_vec; + if (control_output_names) { + control_output_names_vec.reserve(ncontrol_outputs); + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_names_vec.push_back(string(output_names[i])); + } + } + // Compute body nodes. std::vector body_nodes; status->status = tensorflow::ComputeBodyNodes( fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); if (!status->status.ok()) return nullptr; + // Compute body nodes. + std::vector control_output_nodes; + for (int i = 0; i < ncontrol_outputs; ++i) { + control_output_nodes.push_back(&control_outputs[i]->node); + } + // Do the actual function creation. TF_Function* tf_function = new TF_Function(); DCHECK(append_hash_to_fn_name <= 1); status->status = tensorflow::GraphToFunctionDef( fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, - input_tensors, output_tensors, output_names_vec, description, - &tf_function->fdef); + input_tensors, output_tensors, output_names_vec, control_output_nodes, + control_output_names_vec, description, &tf_function->fdef); if (!status->status.ok()) { TF_DeleteFunction(tf_function); return nullptr; @@ -537,6 +624,20 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, + int num_opers, const TF_Operation* const* opers, + int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, + const char* const* output_names, + const TF_FunctionOptions* opts, + const char* description, TF_Status* status) { + return TF_GraphToFunctionWithControlOutputs( + fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs, + inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts, + description, status); +} + const char* TF_FunctionName(TF_Function* func) { return func->fdef.signature().name().c_str(); } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 73fe73769bc1219ce865149d67d333c53371ccc5..946f8c4a2c3fb25f908d809e00bf579b40a8668b 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -1230,6 +1231,53 @@ void DefineFunction(const char* name, TF_Function** func, ASSERT_NE(*func, nullptr); } +REGISTER_OP("CustomOp") + .Output("output: float32") + .Attr("index: int") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); + +void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s, + const char* name, const char* placeholder, + TF_Operation** op) { + TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name); + TF_SetAttrPlaceholder(desc, "index", placeholder); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Operation *node1, *node2, *node3; + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node1", "v1", + &node1); + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node2", "v1", + &node2); + NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2", + &node3); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}}; + func_ = TF_GraphToFunction( + func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1, + /*opers=*/nullptr, 0, inputs, 3, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, /*description=*/nullptr, s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(func_, nullptr); + + // Verify that FunctionDef has 2 attributes, "v1" and "v2". + ASSERT_EQ(func_->fdef.signature().attr().size(), 2); + EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1"); + EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int"); + EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2"); + EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int"); +} + TEST_F(CApiFunctionTest, SetGradientAndRun) { // Define the function and its grad DefineFunction(func_name_, &func_); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 73283d775639b297857b2a50007dc7c28b1f39a3..d520b6b76849e562def6abd8be0510d3b4797e8c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -204,7 +204,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); -Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out); // Set the shapes and types of the output's handle. // diff --git a/tensorflow/c/c_test.c b/tensorflow/c/c_test.c index b86d8eb8e300e02a3871ecd5f424a82c521b18fc..7468122cd567270c8454f886e478be34c2c15cbf 100644 --- a/tensorflow/c/c_test.c +++ b/tensorflow/c/c_test.c @@ -25,6 +25,16 @@ limitations under the License. #include "tensorflow/c/env.h" #include "tensorflow/c/kernels.h" +// A create function. This will never actually get called in this test, it's +// just nice to know that it compiles. +void* create(TF_OpKernelConstruction* ctx) { + TF_DataType type; + TF_Status* s = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "foobar", &type, s); + TF_DeleteStatus(s); + return NULL; +} + // A compute function. This will never actually get called in this test, it's // just nice to know that it compiles. void compute(void* kernel, TF_OpKernelContext* ctx) { @@ -32,12 +42,7 @@ void compute(void* kernel, TF_OpKernelContext* ctx) { TF_Status* s = TF_NewStatus(); TF_GetInput(ctx, 0, &input, s); TF_DeleteTensor(input); - - TF_DataType type; - TF_OpKernelContext_GetAttrType(ctx, "foobar", &type, s); - TF_DeleteStatus(s); - } // Exercises tensorflow's C API. @@ -80,7 +85,7 @@ int main(int argc, char** argv) { TF_StringStreamDone(s); TF_KernelBuilder* b = - TF_NewKernelBuilder("SomeOp", "SomeDevice", NULL, &compute, NULL); + TF_NewKernelBuilder("SomeOp", "SomeDevice", &create, &compute, NULL); TF_RegisterKernelBuilder("someKernel", b, status); TF_DeleteStatus(status); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 257be6379c09841d1427813a0aa25b10a205016d..282f0da302fac89c6fae9f8b5aa4b3c33ab93532 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -211,6 +211,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/core/profiler/rpc/client:capture_profile", "//tensorflow/core:gpu_runtime", ], ) @@ -230,7 +231,6 @@ tf_cuda_cc_test( ":c_api_test_util", "//tensorflow/c:c_test_util", "//tensorflow/cc/profiler", - "//tensorflow/contrib/tpu/profiler:trace_events_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index af13f487af91594fedd4d5f77592682a6f98c34f..45701c7fcf02d5e6ec464ae10d4d20f20ba1d9f0 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -762,11 +762,13 @@ unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(true); + ctx->context.SetShouldStoreGraphs(true); + ctx->context.SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - ctx->context.SetShouldStoreMetadata(false); + ctx->context.SetShouldStoreGraphs(false); + ctx->context.SetShouldStoreStepStats(false); } } // extern "C" diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index dab17505643e791e6294a64247898ae23769a055..ff798593b5f2f77339b668668ff6dafb9f44a2b3 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" using tensorflow::string; @@ -25,10 +26,14 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { op->operation.ConsumeInput(h->handle); } -TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx) { +TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx) { return new TFE_Profiler(ctx); } +bool TFE_ProfilerIsOk(TFE_Profiler* profiler) { + return profiler->profiler->Status().ok(); +} + void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; } void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, @@ -46,7 +51,43 @@ void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler, }; } -void TFE_StartProfilerServer(TFE_Context* ctx, int port) { - auto server_thread = tensorflow::StartProfilerServer(&ctx->context, port); - ctx->context.AddChildThread(std::move(server_thread)); +TFE_ProfilerContext* TFE_NewProfilerContext() { + return new TFE_ProfilerContext; +} + +void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, + TFE_Context* eager_context) { + profiler_context->profiler_context.eager_context = &eager_context->context; +} + +void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { + delete profiler_context; +} + +void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { + // Release child thread intentionally. The child thread can be terminate by + // terminating the main thread. + tensorflow::StartProfilerServer(&context->profiler_context, port).release(); +} + +void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(true); +} + +void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { + ctx->context.SetShouldStoreGraphs(false); +} + +bool TFE_ProfilerClientStartTracing(char* service_addr, char* logdir, + char* worker_list, bool include_dataset_ops, + int duration_ms, int num_tracing_attempts) { + tensorflow::Status s = + tensorflow::profiler::client::ValidateHostPortPair(service_addr); + if (!s.ok()) { + return false; + } + s = tensorflow::profiler::client::StartTracing( + service_addr, logdir, worker_list, include_dataset_ops, duration_ms, + num_tracing_attempts); + return s.ok(); } diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 8c85d0e51695fde09cf0e2bb3930f9173e6cfb54..89523793d37b89ee49c4db844a85f019381ff730 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -25,6 +25,8 @@ extern "C" { TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); +typedef struct TFE_ProfilerContext TFE_ProfilerContext; + // A profiler which will start profiling when creating the object and will stop // when the object is destroyed. It will profile all operations run under the // given TFE_Context. Multiple instance of it can be created, but at most one @@ -32,7 +34,8 @@ TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, // Thread-safety: TFE_Profiler is thread-safe. typedef struct TFE_Profiler TFE_Profiler; -TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_Context* ctx); +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx); +TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler); TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); // The output string is a binary string of tensorflow.tpu.Trace. User can write @@ -42,14 +45,47 @@ TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Return a new profiler context object. +TF_CAPI_EXPORT extern TFE_ProfilerContext* TFE_NewProfilerContext(void); + +// Set the eager context in TFE_ProfilerServerOptions +TF_CAPI_EXPORT extern void TFE_ProfilerContextSetEagerContext( + TFE_ProfilerContext* profiler_context, TFE_Context* eager_context); + +// Destroy a profiler context object. +TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext( + TFE_ProfilerContext* profiler_context); + // Start a profiler grpc server which listens to specified port. It will start -// the server on its own thread. It can be shutdown by destructing TFE_Context. -// Creating multiple profiler server is allowed. The service defined in +// the server on its own thread. It can be shutdown by terminating tensorflow. +// It can be used in both Eager mode and graph mode. Creating multiple profiler +// server is allowed. The service defined in // tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use // tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture tracable // file following // https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. -TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_Context* ctx, int port); +TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context, + int port); + +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); + +// Send a grpc request to profiler server (service_addr) to perform on-demand +// profiling and save the result into logdir which can be visualized by +// TensorBoard. worker_list is the list of worker TPUs separated by ','. Set +// include_dataset_opts to false to profile longer traces. It will block the +// caller thread until receives tracing result. +// This API is designed for TensorBoard, for end user, please use +// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following +// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. +TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing( + char* service_addr, char* logdir, char* worker_list, + bool include_dataset_ops, int duration_ms, int num_tracing_attempts); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index af55fee66e8708e39626da3b10b6dd2f73af92bb..d85048caa7c7f727271352883cb834a2575bd251 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/cc/profiler/profiler.h" -#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/profiler/trace_events.pb.h" using tensorflow::string; @@ -41,9 +41,12 @@ void ExecuteWithProfiling(bool async) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); - TFE_Profiler* profiler = TFE_NewProfiler(ctx); + TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); + TFE_ProfilerContextSetEagerContext(profiler_context, ctx); + TFE_Profiler* profiler = TFE_NewProfiler(profiler_context); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_DeleteProfilerContext(profiler_context); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -70,7 +73,7 @@ void ExecuteWithProfiling(bool async) { TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status); TFE_DeleteProfiler(profiler); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - tensorflow::tpu::Trace profile_proto; + profiler::Trace profile_proto; EXPECT_TRUE(profile_proto.ParseFromString( {reinterpret_cast(profiler_result->data), profiler_result->length})); @@ -100,5 +103,27 @@ void ExecuteWithProfiling(bool async) { TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } +TEST(CAPI, MultipleProfilerSession) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(false)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); + TFE_ProfilerContextSetEagerContext(profiler_context, ctx); + + TFE_Profiler* profiler1 = TFE_NewProfiler(profiler_context); + EXPECT_TRUE(TFE_ProfilerIsOk(profiler1)); + + TFE_Profiler* profiler2 = TFE_NewProfiler(profiler_context); + EXPECT_FALSE(TFE_ProfilerIsOk(profiler2)); + + TFE_DeleteProfiler(profiler1); + TFE_DeleteProfiler(profiler2); + TFE_DeleteProfilerContext(profiler_context); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index b70c0f1c112c675641a023d6c7bf4fa847ee4610..a563e4b8f50f2a90497736f4cb9ca234400bfa04 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -107,9 +107,14 @@ struct TFE_Op { tensorflow::EagerOperation operation; }; +struct TFE_ProfilerContext { + tensorflow::ProfilerContext profiler_context; +}; + struct TFE_Profiler { - TFE_Profiler(TFE_Context* ctx) - : profiler(tensorflow::ProfilerSession::Create(&ctx->context)) {} + TFE_Profiler(TFE_ProfilerContext* ctx) { + profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context); + } std::unique_ptr profiler; }; diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 9505bf9dda32b9a338b574f1d31ec555a5628c6a..71181ae430ab64106e2a75937bd54fbf2efc61ac 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -173,9 +173,10 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { cc_ctx->CtxFailure(s); } -#define DEFINE_TF_GETATTR_(struct_name, func, c_type, cc_type) \ - void struct_name##_GetAttr##func(struct_name* ctx, const char* attr_name, \ - c_type* val, TF_Status* status) { \ +#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ + void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ + const char* attr_name, \ + c_type* val, TF_Status* status) { \ TF_SetStatus(status, TF_OK, ""); \ cc_type v; \ auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); \ @@ -186,10 +187,6 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { } \ } -#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ - DEFINE_TF_GETATTR_(TF_OpKernelConstruction, func, c_type, cc_type) \ - DEFINE_TF_GETATTR_(TF_OpKernelContext, func, c_type, cc_type) - DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index b015d0103969355e8566242bfcc007f697c6ae18..c47bfa8aa3a721d422a0a1536b924f3e53793193 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -137,15 +137,6 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, TF_Status* status); -// Interprets the named kernel context attribute as a TF_DataType and places it -// into *val. *status is set to TF_OK. -// -// If the attribute could not be found or could not be interpreted as -// TF_DataType, *status is populated with an error. -TF_CAPI_EXPORT extern void TF_OpKernelContext_GetAttrType( - TF_OpKernelContext* ctx, const char* attr_name, TF_DataType* val, - TF_Status* status); - #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 0d2954717e7a83c102a35815809a554e3a917e07..608887722f7bca44c884a3426d5e378e9387a530 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -36,6 +36,15 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { struct MyCustomKernel* s = new struct MyCustomKernel; s->created = true; s->compute_called = false; + + // Exercise attribute reads. + TF_DataType type; + TF_Status* status = TF_NewStatus(); + TF_OpKernelConstruction_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + EXPECT_EQ(TF_FLOAT, type); + TF_DeleteStatus(status); + return s; } @@ -43,17 +52,7 @@ static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { struct MyCustomKernel* s = static_cast(kernel); s->compute_called = true; if (ctx != nullptr) { - TF_Status* status = TF_NewStatus(); - EXPECT_EQ(43, TF_StepId(ctx)); - - // Exercise attribute reads. - TF_DataType type; - TF_OpKernelContext_GetAttrType(ctx, "SomeDataTypeAttr", &type, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - EXPECT_EQ(TF_FLOAT, type); - - TF_DeleteStatus(status); } } diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a09becc49b10d2c58f98fbcc11df5190f794c1d4..4c4d587fce04d101b3cc8faebcc3ba04f2f1d0cf 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -150,6 +150,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ], ) @@ -586,6 +587,25 @@ tf_gen_op_wrappers_cc( pkg = "//tensorflow/core", ) +tf_gen_op_wrappers_cc( + name = "tpu_ops", + include_internal_ops = 1, + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + cc_library_with_android_deps( name = "cc_op_gen_main", srcs = [ diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 39593370d1c243e84dc5b6091724d1d404c102b0..43a33cbea6e1e4a50f61cc7d6d8d70cac6a603d2 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -321,6 +321,7 @@ std::pair AttrTypeName(StringPiece attr_type) { {"tensor", {"TensorProto", true}}, {"list(tensor)", {"gtl::ArraySlice", true}}, {"func", {"NameAttrList", true}}, + {"list(func)", {"gtl::ArraySlice", true}}, }; auto entry = attr_type_map->find(attr_type); diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index affd90b1bcc7cb4a8b3ffed6aeeb4bd480f5e314..a7e645e8b556f14f0c7a51d2eba6ab1e2256b837 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -96,7 +96,7 @@ class SymbolicGradientBuilder { // Used to identify nodes at which to stop backprop. std::unordered_set GetStopBackpropNodes( const std::vector& reachable_nodes, - std::unordered_set output_nodes); + const std::unordered_set& output_nodes); const Scope& scope_; const ops::GradOpRegistry* registry_; @@ -167,7 +167,6 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, std::vector SymbolicGradientBuilder::GetReachableNodes() { std::vector reachable_nodes(scope_.graph()->num_node_ids(), false); std::deque queue; - std::vector visited(scope_.graph()->num_node_ids(), false); for (const Output& out : outputs_) { if (!reachable_nodes[out.node()->id()]) { queue.push_back(out.node()); @@ -180,10 +179,10 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { queue.pop_front(); for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; - if (visited[e->src()->id()]) continue; - queue.push_back(e->src()); - reachable_nodes[e->src()->id()] = true; - visited[e->src()->id()] = true; + if (!reachable_nodes[e->src()->id()]) { + queue.push_back(e->src()); + reachable_nodes[e->src()->id()] = true; + } } } return reachable_nodes; @@ -191,7 +190,7 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( const std::vector& reachable_nodes, - std::unordered_set output_nodes) { + const std::unordered_set& output_nodes) { // Output nodes that get transitively consumed by other `outputs_` are stored // in `internal_outputs`. std::unordered_set internal_outputs; diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 52345a376cc29ee47ccb9888c9bb26292468b5a9..dedd55f16afb879ea966dc89d14d88ee15d9e83e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -81,6 +81,7 @@ cc_library( ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", ]) + if_android([ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 10f7abf09e925c0c31cfd595ecee4605f189476f..66260fcf4a9b24f78d45010c6e86d4ee398b6d3d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" -#include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index 1833b6a65eef9baa2e92a13d9c4d44b79620de2f..2cf68c9cd8396987899b4f34f21b994b4722ead4 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -19,13 +19,19 @@ from __future__ import division as _division from __future__ import print_function as _print_function import os as _os +import sys as _sys # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # API IMPORTS PLACEHOLDER from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorboard.summary._tf.summary'), + error_msg=( + "Limited tf.compat.v2.summary API due to missing TensorBoard " + "installation")) _component_api_helper.package_hook( parent_package_str=__name__, child_package_str=( @@ -41,3 +47,10 @@ _component_api_helper.package_hook( # # This make this one symbol available directly. from tensorflow.python.compat.v2_compat import enable_v2_behavior # pylint: disable=g-import-not-at-top + +# Add module aliases +_current_module = _sys.modules[__name__] +if hasattr(_current_module, 'keras'): + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 16151e77737429f4fbf690fc34b12a70bacebdc4..af016bf80e7a10d8729a1eb385466af48b5810cd 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -30,6 +30,7 @@ cc_library( "flags.h", ], deps = [ + ":aot_only_var_handle_op", ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:cpu_function_runtime", @@ -71,6 +72,7 @@ tf_cc_test( ":tfcompile_lib", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", @@ -205,6 +207,15 @@ cc_library( ], ) +cc_library( + name = "aot_only_var_handle_op", + srcs = ["aot_only_var_handle_op.cc"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + ], + alwayslink = 1, +) + tf_cc_test( name = "benchmark_test", srcs = ["benchmark_test.cc"], diff --git a/tensorflow/compiler/aot/aot_only_var_handle_op.cc b/tensorflow/compiler/aot/aot_only_var_handle_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ce36a979f424610a5aa952afa8db2245ed971a9 --- /dev/null +++ b/tensorflow/compiler/aot/aot_only_var_handle_op.cc @@ -0,0 +1,56 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Implementation of varhandle that binds a VarHandleOp to an XlaResource of the +// same name. It is not safe to use this op in a JIT context. +class XlaAotOnlyVarHandleOp : public XlaOpKernel { + public: + explicit XlaAotOnlyVarHandleOp(OpKernelConstruction* c); + void Compile(XlaOpKernelContext* context) override; + + private: + string name_; +}; + +XlaAotOnlyVarHandleOp::XlaAotOnlyVarHandleOp(OpKernelConstruction* c) + : XlaOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("shared_name", &name_)); +} + +void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) { + // Look for a resource of the same name. TF also keys that on the container + // and type attributes, but that doesn't seem necessary. + for (const auto& resource : context->xla_context()->resources()) { + if (resource->kind() == XlaResource::kVariable && + resource->name() == name_) { + context->SetResourceOutput(0, resource.get()); + return; + } + } + context->SetStatus( + errors::InvalidArgument("Variable: ", name_, " not configured")); +} +} // namespace + +REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index d016632da2a9d7c2c2f81c02dd573787a0502923..da0598736a7d6b7f55458d76ca30fa6ad46a74f9 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -168,12 +168,12 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (config.feed_size() != num_args) { - return errors::InvalidArgument("mismatch between feed_size(", - config.feed_size(), ") and num_args(", - num_args, ")"); + if (config.feed_size() + config.variable_size() != num_args) { + return errors::InvalidArgument( + "mismatch between feed_size(", config.feed_size(), ")+variable_size(", + config.variable_size(), ") and num_args(", num_args, ")"); } - for (int i = 0; i < num_args; ++i) { + for (int i = 0; i < config.feed_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); @@ -212,12 +212,14 @@ Status GenResultMethods(const tf2xla::Config& config, // tuple result, and we rely on this to simplify code generation. return errors::Internal("codegen requires the XLA result to be a tuple"); } - if (config.fetch_size() != ps.result().tuple_shapes_size()) { + size_t num_results = ps.result().tuple_shapes_size(); + if (config.fetch_size() + config.variable_size() != num_results) { return errors::InvalidArgument("mismatch between fetch_size(", - config.feed_size(), ") and tuple_size(", + config.fetch_size(), ")+variable_size(", + config.variable_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { + for (int i = 0; i < config.fetch_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR(AddRewritesForShape( i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); @@ -245,6 +247,51 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generate methods for variables. +Status GenVariableMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, string* methods) { + size_t num_args = ps.parameters_size(); + for (int i = config.feed_size(); i < num_args; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); + const string code = R"( + void set_var_{{NAME}}_input_data({{TYPE}}* data) { + set_arg_data({{I}}, data); + } +)"; + const tf2xla::Variable& var = config.variable(i - config.feed_size()); + *methods += RewriteWithName( + var.name().empty() ? var.node_name() : var.name(), code, rewrites); + } + size_t num_results = ps.result().tuple_shapes_size(); + for (int i = config.fetch_size(); i < num_results; ++i) { + std::vector> rewrites; + TF_RETURN_IF_ERROR(AddRewritesForShape( + i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); + string code = R"( + {{TYPE}}* var_{{NAME}}_result_data() { + return static_cast<{{TYPE}}*>(result_data({{I}})); + } + {{TYPE}}& var_{{NAME}}_result({{DIM_VARS}}) { + return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( + result_data({{I}}))){{INDICES}}; + } + const {{TYPE}}* var_{{NAME}}_result_data() const { + return static_cast(result_data({{I}})); + } + const {{TYPE}}& var_{{NAME}}_result({{DIM_VARS}}) const { + return (*static_cast( + result_data({{I}}))){{INDICES}}; + } +)"; + const tf2xla::Variable& var = config.variable(i - config.fetch_size()); + *methods += RewriteWithName( + var.name().empty() ? var.node_name() : var.name(), code, rewrites); + } + return Status::OK(); +} + // Generates code implementing {Arg,Result}Names(), where T is one of // tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string // literal in the array, with nullptr terminating the array. @@ -291,6 +338,14 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name")); } } + for (const tf2xla::Variable& variable : config.variable()) { + if (!variable.name().empty()) { + TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name")); + } else { + TF_RETURN_IF_ERROR( + ValidateCppIdent(variable.node_name(), "variable name")); + } + } return Status::OK(); } @@ -339,9 +394,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); const xla::ProgramShapeProto& ps = compile_result.program_shape; - string methods_arg, methods_result; + string methods_arg, methods_result, methods_variable; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); + TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( buffer_infos_for_args.data(), buffer_infos_for_args.size(), /*allocate_entry_params=*/true); @@ -523,6 +579,21 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} + // Methods for managing variable buffers. Buffers are in row-major order. The + // input and output buffers may or may not be identical. + // + // void set_var_X_data(T* data) + // Sets the buffer for variable X. + // + // T* var_X_data() + // Returns the buffer of type T for variable X. + // + // T& var_X(...dim indices...) + // Returns a reference to the value of type T for variable X, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. +{{METHODS_VARIABLE}} + private: // Number of buffers for the compiled computation. static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; @@ -589,6 +660,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { include_hlo_profile_printer_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, + {"{{METHODS_VARIABLE}}\n", methods_variable}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index c1788ca32a1d099284eeb870f9513891051fd29e..5580e55b691bd10698b63d86bc0194b25da743b9 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -172,6 +174,15 @@ TEST(CodegenTest, Golden) { tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); + tf2xla::Variable* variable = config.add_variable(); + variable->set_node_name("myvar"); + variable->mutable_shape()->add_dim()->set_size(1); + variable->set_type(DT_FLOAT); + tf2xla::Variable* variable2 = config.add_variable(); + variable2->set_node_name("my/var"); + variable2->set_name("myvar2"); + variable2->mutable_shape()->add_dim()->set_size(5); + variable2->set_type(DT_INT32); CompileResult compile_result; compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( {}, @@ -186,9 +197,14 @@ TEST(CodegenTest, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::S32, {5}), }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(xla::U32, {5, 6}), + xla::ShapeUtil::MakeShape(xla::F32, {1}), + xla::ShapeUtil::MakeShape(xla::S32, {5}), + })) .ToProto(); compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 35994fc785d3e1d5e883c49bec96de315e189d2e..b5f33d690d492489e9090786cd341e035ae7ca15 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -52,7 +52,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): f32[1], (unknown): s32[5]) -> (u32[5,6], f32[1], s32[5]) // // Memory stats: // arg bytes total: 104 @@ -214,6 +214,58 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { result_data(0)))[dim0][dim1]; } + // Methods for managing variable buffers. Buffers are in row-major order. The + // input and output buffers may or may not be identical. + // + // void set_var_X_data(T* data) + // Sets the buffer for variable X. + // + // T* var_X_data() + // Returns the buffer of type T for variable X. + // + // T& var_X(...dim indices...) + // Returns a reference to the value of type T for variable X, + // with dim indices specifying which value. No bounds checking is performed + // on dim indices. + + void set_var_myvar_input_data(float* data) { + set_arg_data(2, data); + } + + void set_var_myvar2_input_data(tensorflow::int32* data) { + set_arg_data(3, data); + } + + float* var_myvar_result_data() { + return static_cast(result_data(1)); + } + float& var_myvar_result() { + return (*static_cast( + result_data(1)))[0]; + } + const float* var_myvar_result_data() const { + return static_cast(result_data(1)); + } + const float& var_myvar_result() const { + return (*static_cast( + result_data(1)))[0]; + } + + tensorflow::int32* var_myvar2_result_data() { + return static_cast(result_data(2)); + } + tensorflow::int32& var_myvar2_result(size_t dim0) { + return (*static_cast( + result_data(2)))[dim0]; + } + const tensorflow::int32* var_myvar2_result_data() const { + return static_cast(result_data(2)); + } + const tensorflow::int32& var_myvar2_result(size_t dim0) const { + return (*static_cast( + result_data(2)))[dim0]; + } + private: // Number of buffers for the compiled computation. static constexpr size_t kNumBuffers = 6; @@ -257,7 +309,7 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShapeProto* StaticProgramShape() { static const xla::ProgramShapeProto* kShape = []() { xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 64); + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 132); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index 7f7b96428572705f30144e6c95cd4cf9c44ce2a3..2884597abcf29583e6192296b0e4ce6825d7c01a 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 9fc223bdc7c0e207ce2005cb86250aa77e709df8..0e46a9f5e9d68fa2174f7bd9b9fa7c3a82dfb715 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -108,10 +108,13 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, computation.Snapshot()); // Serialize the HloSnapshot deterministically so that all the outputs of a // tf_library genrule are deterministic. - string proto; - TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); + const size_t size = module->ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK( + SerializeToBufferDeterministic(*module, serialized.get(), size)); TF_RETURN_IF_ERROR( - WriteStringToFile(Env::Default(), flags.out_session_module, proto)); + WriteStringToFile(Env::Default(), flags.out_session_module, + absl::string_view(serialized.get(), size))); } xla::cpu::CpuAotCompilationOptions aot_opts( flags.target_triple, flags.target_cpu, flags.target_features, diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 10fa33ab5e84dcbc1629bee6214e8969046f19c2..444264ba6e1f59c33551796025ba845c62c02d43 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -69,6 +69,7 @@ genrule( "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", "test_graph_tftop_k.pb", + "test_graph_tfvariable.pb", ], # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # GPUs which might be present. This is important because builds may run @@ -222,6 +223,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfvariable", + testonly = 1, + config = "test_graph_tfvariable.config.pbtxt", + cpp_class = "VariableComp", + graph = "test_graph_tfvariable.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -241,6 +253,7 @@ tf_cc_test( ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", ":test_graph_tftop_k", + ":test_graph_tfvariable", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 7bac79ec062af7e790134286e34eda4e123e138a..42f8812def0503824416d92daa2db71a64c3db88 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -149,6 +149,14 @@ def tftop_k(_): array_ops.identity(output[1], name='indices') +def tfvariable(_): + x = variables.Variable(1000.0, name='x') + old_x = x.value() + with ops.control_dependencies([old_x]): + new_x = x.assign_add(42.0) + array_ops.stack([old_x, new_x], name='result') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -171,6 +179,7 @@ def main(_): write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) write_graph(tftop_k, FLAGS.out_dir) + write_graph(tfvariable, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9b4c4215a330b014f595edde001aba73ad7d8263 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfvariable.config.pbtxt @@ -0,0 +1,12 @@ +# Text form of tensorflow.tf2xla.Config proto. +fetch { + id { node_name: "result" } +} + +variable { + node_name: "x" + shape { + dim { size: 1 } + } + type: DT_FLOAT +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 4dd79e5882d7da61be029735ef2b165908c599f9..5f9316f3933713e12fc5960b9adfecc6e9bd99b5 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -473,6 +474,28 @@ TEST(TFCompileTest, TopK) { EXPECT_EQ(expected_indices[1], fn.result1(1)); } +TEST(TFCompileTest, Variable) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + VariableComp fn; + float x = 23; + fn.set_var_x_input_data(&x); + + fn.set_thread_pool(&device); + fn.Run(); + EXPECT_EQ(fn.result0(0, 0), 23); + EXPECT_EQ(fn.result0(1, 0), 65); + EXPECT_EQ(fn.var_x_result(), 65); + + EXPECT_EQ(x, 23); + x = fn.var_x_result(); + fn.Run(); + EXPECT_EQ(fn.result0(0, 0), 65); + EXPECT_EQ(fn.result0(1, 0), 107); + EXPECT_EQ(fn.var_x_result(), 107); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 3cae081ce7c78226390a82d222d57ac653c14321..121de401cefb2b56b984944dde769f226590dc67 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -208,6 +208,7 @@ cc_library( "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", @@ -282,7 +283,6 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -465,7 +465,6 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 9f4042630edaec1b9519b6434d859a48372e8b15..285b1efa53d91922c9fa161cfd2de34e1434d0c4 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -115,6 +115,13 @@ void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) { return; } + if (ctrl_edges.size() == 1 && ctrl_edges.front()->dst()->IsSink()) { + // Avoid creating a Merge node if we can just add an edge to _SINK + // instead. + s.graph()->AddControlEdge(new_node, s.graph()->sink_node()); + return; + } + // We can't merge control edges directly so we instead first "convert" them to // normal values that can be merged, merge the values and then "convert" the // merged value back into control. diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 390ffa694b6f127544d92f3024a02d877556aacd..c14c7465c55b7d350d6b3a6853cef6692140ce78 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -68,6 +68,8 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { } } + FixupSourceAndSinkEdges(graph.get()); + GraphOptimizationPassOptions opt_options; opt_options.graph = &graph; BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); @@ -223,5 +225,23 @@ TEST_F(BuildXlaOpsTest, OnXlaDevice) { ASSERT_NE(write_op_new, nullptr); EXPECT_THAT(write_op_new, assign_var); } + +TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* sink_node = graph->sink_node(); + EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")), + NodeWith(Op("cluster_0")), + NodeWith(Op("NoOp"))))); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0ef0d3db8c16e4b3f78d29aad5a2ae75a81d96f6..4397eea9af266cbd0392f08323e59077c9395150 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -113,7 +113,11 @@ class Predicate { enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol }; virtual string ToString() const = 0; - int64 hash() const { return hash_; } + + // An ID assigned to the Predicate at construction time. Conceptually like a + // pointer, except that it is stable across runs. + int64 id() const { return id_; } + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; @@ -126,29 +130,19 @@ class Predicate { static void Visit(Predicate* p, const FunctionTy& func); protected: - explicit Predicate(int64 hash) : hash_(hash) {} + explicit Predicate(int64 id) : id_(id) {} private: - const int64 hash_; + const int64 id_; TF_DISALLOW_COPY_AND_ASSIGN(Predicate); }; -int64 HashPredicateSequence(Predicate::Kind kind, - absl::Span preds) { - int64 hash = ::tensorflow::hash()(kind); - for (Predicate* pred : preds) { - hash = Hash64Combine(hash, pred->hash()); - } - return hash; -} - // Represents a logical conjunction of a set of predicates. class AndPredicate : public Predicate { public: - explicit AndPredicate(std::vector operands) - : Predicate(HashPredicateSequence(Kind::kAnd, operands)), - operands_(std::move(operands)) {} + explicit AndPredicate(int64 id, std::vector operands) + : Predicate(id), operands_(std::move(operands)) {} string ToString() const override { if (operands().empty()) { @@ -177,9 +171,8 @@ class AndPredicate : public Predicate { // Represents a logical disjunction of a set of predicates. class OrPredicate : public Predicate { public: - explicit OrPredicate(std::vector operands) - : Predicate(HashPredicateSequence(Kind::kOr, operands)), - operands_(std::move(operands)) {} + explicit OrPredicate(int64 id, std::vector operands) + : Predicate(id), operands_(std::move(operands)) {} string ToString() const override { if (operands().empty()) { @@ -207,9 +200,8 @@ class OrPredicate : public Predicate { // Represents a logical negation of a set of predicates. class NotPredicate : public Predicate { public: - explicit NotPredicate(Predicate* operand) - : Predicate(HashPredicateSequence(Kind::kNot, {operand})), - operands_({operand}) {} + explicit NotPredicate(int64 id, Predicate* operand) + : Predicate(id), operands_({operand}) {} string ToString() const override { return absl::StrCat("~", operand()->ToString()); @@ -246,11 +238,9 @@ class NotPredicate : public Predicate { // iterations). class AndRecurrencePredicate : public Predicate { public: - explicit AndRecurrencePredicate(Predicate* start, Predicate* step, + explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step, std::vector frame) - : Predicate(Hash(start, step, frame)), - operands_({start, step}), - frame_(std::move(frame)) {} + : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {} Predicate* start() const { return operands_[0]; } Predicate* step() const { return operands_[1]; } @@ -270,16 +260,6 @@ class AndRecurrencePredicate : public Predicate { private: std::array operands_; std::vector frame_; - - static int64 Hash(Predicate* start, Predicate* step, - const std::vector& frame) { - uint64 frame_hash = 0; - for (const string& sub_frame : frame) { - frame_hash = Hash64Combine(Hash64(sub_frame), frame_hash); - } - return Hash64Combine( - HashPredicateSequence(Kind::kAndRecurrence, {start, step}), frame_hash); - } }; // Represents an uninterpreted symbol in a logical predicate. @@ -289,8 +269,8 @@ class AndRecurrencePredicate : public Predicate { // symbols. class SymbolPredicate : public Predicate { public: - explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) - : Predicate(Hash(tensor_id, must_be_true)), + explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true) + : Predicate(id), tensor_id_(std::move(tensor_id)), must_be_true_(must_be_true) {} @@ -313,13 +293,6 @@ class SymbolPredicate : public Predicate { private: TensorId tensor_id_; bool must_be_true_; - - static int64 Hash(const TensorId tensor_id, bool must_be_true) { - return Hash64Combine( - ::tensorflow::hash()(must_be_true), - Hash64Combine(::tensorflow::hash()(Kind::kSymbol), - TensorId::Hasher{}(tensor_id))); - } }; template @@ -477,8 +450,11 @@ class PredicateFactory { template std::unique_ptr Make(Args&&... args) { + // If we ever expose the Predicate class outside this .cc file then we may + // want to make this hard to misuse (by accidentally passing in an arbitrary + // integer to the Predicate constructor for instance). return std::unique_ptr( - new PredicateT(std::forward(args)...)); + new PredicateT(id_counter_++, std::forward(args)...)); } Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); @@ -559,6 +535,7 @@ class PredicateFactory { absl::flat_hash_map, HashSignatureForSymbol> interned_symbol_instances_; + int64 id_counter_ = 0; int stack_depth_ = 0; }; @@ -566,7 +543,7 @@ Predicate* PredicateFactory::MakeInternedAndOr( std::vector simplified_ops, Predicate::Kind pred_kind) { std::stable_sort( simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + [](Predicate* a, Predicate* b) { return a->id() < b->id(); }); auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); if (it != interned_and_or_instances_.end()) { diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 16ee8f86d55c72785368ac2fd67635eba2fa7cd7..38a5118d9a721b814e1b52ce4202d4fb783e3ac3 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -521,7 +521,7 @@ TEST(DeadnessAnalysisTest, Loop) { EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + "({#true,&,*iv0/cond:0} & {#true,&,*iv1/cond:0})"); EXPECT_EQ(predicate_map[ControlOutputFor(add1)], "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); } @@ -553,11 +553,11 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + "{#true,&,(iv0/iv:0 & *iv0/cond:0)}"); } } @@ -643,22 +643,23 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], - "{({#true,&,*iv_outer/cond:0} & " - "*iv_outer/cond:0),&,*iv_inner/cond:0}"); + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " - "iv_inner/iv:0)}"); + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " - "iv_inner/iv:0)}"); + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " - "iv_inner/iv:0)}"); + "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " + "*iv_inner/cond:0)}"); } } @@ -702,20 +703,21 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])], "{#true,&,*iv_outer/cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])], - "{({#true,&,*iv_outer/cond:0} & " - "*iv_outer/cond:0),&,*iv_inner/cond:0}"); + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])], "{#true,&,*iv_outer/cond_1:0}"); - EXPECT_EQ( - predicate_map[ControlOutputFor(inner_iv[1])], - "{({#true,&,*iv_outer/cond_1:0} & " - "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0}"); - EXPECT_EQ( - predicate_map[ControlOutputFor(add0)], - "({({#true,&,*iv_outer/cond:0} & " - "*iv_outer/cond:0),&,*iv_inner/cond:0} & " - "{({#true,&,*iv_outer/cond_1:0} & " - "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0})"); + EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])], + "{(*iv_outer/cond_1:0 & " + "{#true,&,*iv_outer/cond_1:0}),&,*iv_inner/" + "cond_1:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0} & {(*iv_outer/cond_1:0 & " + "{#true,&,*iv_outer/cond_1:0}),&,*iv_inner/" + "cond_1:0})"); } } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 1f8ec09e19c01d0a8b2a3761135ed53dfb2ad3b0..261519de3478c8b3e30d206a15944b5a686598e2 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -307,22 +307,6 @@ REGISTER_OP("XlaHostCompute") .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); -REGISTER_OP("_XlaSendFromHost") - .Input("inputs: Tinputs") - .Input("dynamic_key: string") - .Attr("Tinputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - -REGISTER_OP("_XlaRecvAtHost") - .Input("dynamic_key: string") - .Output("outputs: Toutputs") - .Attr("Toutputs: list(type) >= 0") - .Attr("key: string") - .Attr("device_ordinal: int") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 3bb979e0698d2d6be42ed5bae66c25267928192c..6d1661222e3eaf9df4f9f91f2b426c80b55245b2 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index ec745cdbb7e237f8b4935dd41e9791fc75f5355d..f0c9d573451952a398dce190e102a33270a4d739 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,13 +15,17 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -36,6 +40,25 @@ namespace { const char* const kXlaClusterOutput = "XlaClusterOutput"; +bool IsCpuGpuCompile(const Graph* graph) { + for (Node* n : graph->nodes()) { + string name; + // Only consider nodes being compiled. + if (!GetNodeAttr(n->attrs(), + EncapsulateXlaComputationsPass::kXlaClusterAttr, &name) + .ok()) + continue; + // Early return for any node with a device that is not a CPU or GPU. + DeviceNameUtils::ParsedName parsed; + if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { + if (parsed.type != DEVICE_CPU && parsed.type != DEVICE_GPU) { + return false; + } + } + } + return true; +} + // Checks if a graph node is marked to be a guaranteed constant. bool is_guaranteed_constant(const Node& n) { bool guaranteed_constant = false; @@ -173,10 +196,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // Nondeterminism in serialization would not lead to incorrect results, but // may cause spurious cache misses. DeterministicSerialization is a // best-effort deterministic serialization. - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); - LOG(INFO) << "Subgraph fingerprint:" << fingerprint; + const size_t size = gdef.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size)); + uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); + VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return Status::OK(); } @@ -351,12 +375,19 @@ Status EncapsulateXlaComputationsPass::Run( << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before", **options.graph, options.flib_def); - TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def)); + const char* additional_help = + IsCpuGpuCompile(options.graph->get()) + ? xla::status_macros::kPossibleAutoJitAlternative + : ""; + + TF_RETURN_WITH_CONTEXT_IF_ERROR(Encapsulate(options.graph, options.flib_def), + additional_help); VLOG(1) << "EncapsulateXlaComputations() half-way: " << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway", **options.graph, options.flib_def); - TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get())); + TF_RETURN_WITH_CONTEXT_IF_ERROR(BuildXlaLaunchOps(options.graph->get()), + additional_help); VLOG(1) << "EncapsulateXlaComputations() finished: " << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after", **options.graph, options.flib_def); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index ebfffc3267e5acdf593bea2517c447083133e39c..5287fd175df206970b9fa73bc6b0176eddcdcaa9 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -247,6 +247,7 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice")); Scope host_scope = main_scope.WithAssignedDevice(host_name); + // In the future we may want to be clever here and avoid the extra Cast ops. SliceInputs slice_inputs_int64 = MakeSliceIndexAndSizeInt64(host_scope, slice_inputs); @@ -312,9 +313,9 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// Return true if `n` is a slice we can rewrite to have a static shape +// Return true if `n` is a slice we should rewrite to have a static shape // (i.e. have the output shape only depend on the "size" input). -xla::StatusOr IsRewritableSlice(Node* n) { +xla::StatusOr ShouldRewriteSlice(Node* n) { if (n->type_string() != "Slice") { return false; } @@ -332,14 +333,20 @@ xla::StatusOr IsRewritableSlice(Node* n) { // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. - return absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); + bool slice_size_has_error = absl::c_all_of( + slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; }); + if (!slice_size_has_error) { + return false; + } + + // No point in rewriting slices that have both size and begin as constants. + return !slice_inputs->begin.node()->IsConstant(); } Status FindAndRewriteSlices(Graph* g, bool* changed) { std::vector slices_to_rewrite; for (Node* n : g->nodes()) { - TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n)); if (is_rewritable) { slices_to_rewrite.push_back(n); } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 32e30216de565b4c1918903bf6c70c321c38cbb3..2add2c13f92f561904163012ee16cc17ce5badce 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -432,5 +432,26 @@ TEST(SliceToDynamicSliceRewriteTest, WithControlDepsToConstant) { Name("dependency"))))); } +TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Const(root.WithOpName("begin"), {10, 10}); + Output size = ops::Const(root.WithOpName("size"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* slice_node = testing::FindNodeByName(result.get(), "slice"); + EXPECT_THAT(slice_node, + NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))), + Out(NodeWith(Op("Const"))), + Out(NodeWith(Op("Const")))))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index bab824c15f8f27f5325e79cd92d50cdaad850233..d0fa2c40be9d6b13ec736a9d6483dae0b4f0f45e 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -19,6 +19,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index ad71df5a694a5f8da94675049df1062a7edb6253..997ef6e14bb9bd16ddac13eaf67368966818b29e 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" @@ -35,6 +36,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" @@ -304,10 +307,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - OP_REQUIRES_OK( - ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/false, &client, - &variables, &kernel, &executable)); + { + Status s = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false, + &client, &variables, &kernel, &executable); + if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU || + platform_info_.device_type().type_string() == DEVICE_GPU)) { + // Suggest auto jit if the failure was with GPU or CPU. + errors::AppendToMessage(&s, + xla::status_macros::kPossibleAutoJitAlternative); + } + + OP_REQUIRES_OK(ctx, s); + } se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 20c2cd7e0561f92a01486102c4d2c572fd80c957..d9a83049d6352f04f9237f21b44bdb5ea18e518a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1176,6 +1176,8 @@ Status MarkForCompilationPass::RunImpl( if (absl::optional cluster_name = GetXlaClusterForNode(*n)) { n->set_name(absl::StrCat(*cluster_name, "/", n->name())); + } else if (n->type_string() == "VarHandleOp") { + n->set_name(absl::StrCat("varhandle/", n->name())); } else { // There is room for improvement here. In particular, it may help to // split these unclustered nodes into classes where every node in a diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 80993861abba050fa3d6a133023d3c99f41f73e3..3adcfef4dacecb343812cefc3a893a65c74ca101 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -43,7 +43,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, return ""; } - auto node_name = [cycles, &graph](int node_id) { + auto node_name = [&graph](int node_id) { if (!FastBoundsCheck(node_id, graph.num_node_ids())) { return string("(null)"); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index c67b4f11b030f22c123336327ff9fa67b1211d7a..56c4220f12b54be09821eca4590df52e8e71850b 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -102,7 +102,8 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - absl::make_unique(); + absl::make_unique( + backend->stream_executors()[device_ordinal]); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -428,7 +429,7 @@ void XlaDevice::Sync(const DoneCallback& done) { // moment--when ThenEnqueueOnBackgroundThread is called--will have finished. // This achieves a device-wide sync. stream->ThenEnqueueOnBackgroundThread( - [this, stream, done](se::StreamExecutor*) { + [stream, done](se::StreamExecutor*) { tracing::ScopedActivity activity("XlaDevice::Sync::Callback", /*is_expensive=*/true); done(stream->ok() ? Status::OK() @@ -479,6 +480,23 @@ bool XlaDevice::AllowsSyncOnCompletion() const { return sync_on_completion_; } +void XlaDevice::SetHandleDeviceErrorCallback(std::function callback) { + mutex_lock lock(mu_); + device_error_callback_ = callback; +} + +Status XlaDevice::HandleDeviceError() { + std::function local_device_error_callback; + { + mutex_lock lock(mu_); + local_device_error_callback = device_error_callback_; + } + if (local_device_error_callback != nullptr) { + return local_device_error_callback(); + } + return Status::OK(); +} + Status XlaDevice::RefreshStatus() { std::shared_ptr stream; { @@ -488,8 +506,14 @@ Status XlaDevice::RefreshStatus() { if (!stream) { return Status::OK(); } - // Stream status is XlaDevice status, no extra operations needed. - return stream->RefreshStatus(); + Status status = stream->RefreshStatus(); + if (!status.ok()) { + // Ignore errors from HandleDeviceError, since by definition the status is + // already non-ok, so there's nothing extra to report if HandleDeviceError + // itself returns an error. + HandleDeviceError().IgnoreError(); + } + return status; } XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 5fe1290fa03f2b1f9d90e36dbc5769b3c2728c8d..977f5f5cf151d979d025c2966012445af04fc502 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -171,6 +171,9 @@ class XlaDevice : public LocalDevice { void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + // Installs an error handling callback when RefreshStatus sees !status.ok(). + void SetHandleDeviceErrorCallback(std::function callback); + Status RefreshStatus() override LOCKS_EXCLUDED(mu_); private: @@ -187,6 +190,9 @@ class XlaDevice : public LocalDevice { static Status GetMetadataFromDevice(DeviceBase* device, const XlaDevice::Metadata** metadata); + // Handles error when RefreshStatus sees !status.ok(). + Status HandleDeviceError(); + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; @@ -237,6 +243,9 @@ class XlaDevice : public LocalDevice { // regardless of status. bool sync_on_completion_ GUARDED_BY(mu_) = true; + // A callback that will be invoked when RefreshStatus sees a status error. + std::function device_error_callback_ GUARDED_BY(mu_); + // Set of devices to use. This controls which of the devices on the given // platform will have resources allocated. For GPUs this will be // filled from visible_gpu_devices list from session configuration. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 28681bb8b03dbf97e8145972f9a04b5855fafdae..05b9c511866d3ca48ec3519bee8a4dbf6086f6ac 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -29,7 +29,10 @@ limitations under the License. namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator() {} +XlaDeviceAllocator::XlaDeviceAllocator( + stream_executor::StreamExecutor* stream_executor) + : stream_executor_(stream_executor) {} + XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } @@ -48,7 +51,21 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { delete XlaTensor::FromOpaquePointer(ptr); } -void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } +absl::optional XlaDeviceAllocator::GetStats() { + absl::optional se_stats = + stream_executor_->GetAllocatorStats(); + if (!se_stats) { + return absl::nullopt; + } + + tensorflow::AllocatorStats tf_stats; + tf_stats.num_allocs = se_stats->num_allocs; + tf_stats.bytes_in_use = se_stats->bytes_in_use; + tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use; + tf_stats.largest_alloc_size = se_stats->largest_alloc_size; + tf_stats.bytes_limit = se_stats->bytes_limit; + return tf_stats; +} XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index e45db989fac720df6c3458c93a6b8dbb0919f930..1ce64ad323b4827adc2f4d48841315fbde43e532 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -34,14 +34,18 @@ namespace tensorflow { // empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(); + XlaDeviceAllocator(se::StreamExecutor* stream_executor); ~XlaDeviceAllocator() override; string Name() override; void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void GetStats(AllocatorStats* stats) override; + absl::optional GetStats() override; + + private: + // The stream executor of the device. + se::StreamExecutor* stream_executor_; }; // Helper class for managing data transfers between host and XLA devices. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 927f983ba9ef23c8509523f42366c0c89c29db9f..09e04d22def9c39f45c2737c1d4a5e7787e3fdc0 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/data/generator_dataset_op.h" #include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/function_ops.h" @@ -241,6 +242,8 @@ class XlaAssignVariableOp : public OpKernel { data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ + data::IteratorGetNextAsOptionalOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ data::IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ @@ -251,6 +254,15 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .HostMemory("string_handle"), \ data::IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE), \ + data::OptionalNoneOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE), \ + data::OptionalFromValueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"), \ + data::OptionalHasValueOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE), \ + data::OptionalGetValueOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 9b6ca4092c3177ac26503add13bce25d2c0bb820..7c1e0daf0b7b418530367cb80fbd18b93e8e5f5e 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -250,6 +250,29 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "self_adjoint_eig_op_test", + size = "medium", + srcs = ["self_adjoint_eig_op_test.py"], + # TODO(kuny): remove it after b/124377352 is fixed. + disabled_backends = [ + "cpu", + "gpu", + "cpu_ondemand", + ], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:map_fn", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "matrix_triangular_solve_op_test", size = "small", diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5d5e486f616937601214aa169a4c329ab78932c8..eec69ea7d2d9af9ff570f927fb25b668ccce2b97 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -119,7 +119,7 @@ class CategoricalTest(xla_test.XLATestCase): def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. - num_samples = 21000 + num_samples = 40000 rand_probs = np.random.dirichlet([1., 1., 2., 3.]) rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched diff --git a/tensorflow/compiler/tests/plugin.bzl b/tensorflow/compiler/tests/plugin.bzl index fbc8781a3e59faecf985cde5114bf56a041c4be0..46a854d1459b7ea9d9fe3cf7689faee557c2cf84 100644 --- a/tensorflow/compiler/tests/plugin.bzl +++ b/tensorflow/compiler/tests/plugin.bzl @@ -18,13 +18,12 @@ # git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl plugins = { - #"example": { - # "device":"XLA_MY_DEVICE", - # "types":"DT_FLOAT,DT_HALF,DT_INT32", - # "tags":[], - # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], - # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], - # "deps":[], - #}, + #"example": { + # "device":"XLA_MY_DEVICE", + # "types":"DT_FLOAT,DT_HALF,DT_INT32", + # "tags":[], + # "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"], + # "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"], + # "deps":[], + #}, } - diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 97ffad34c00b8ec16eb1ec109ba5d980e0ce673d..34f2465ba63f235f893db9dd6930ac252c3e7226 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -122,8 +122,8 @@ class RandomOpsTest(xla_test.XLATestCase): beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) - self.assertTrue((y >= a).sum() == count) - self.assertTrue((y <= b).sum() == count) + self.assertEqual((y >= a).sum(), count) + self.assertEqual((y <= b).sum(), count) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 693f8513bc54e30060a2e963abd504768535a50a..a9a87b8fb3104f8b9870c41e2aa28b0c48c12921 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -134,6 +134,12 @@ class ScatterNdTest(xla_test.XLATestCase): expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8])) + def testRepeatedIndices(self): + indices = np.array([[0], [1], [0], [1]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([20, 22], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2])) + def testSimple2(self): indices = np.array([[1, 0], [1, 1]], dtype=np.int32) updates = np.array([11., 12.], dtype=np.float32) diff --git a/tensorflow/compiler/tests/self_adjoint_eig_op_test.py b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb5c82b22ea1d7400b54045edee0ca0782ce979 --- /dev/null +++ b/tensorflow/compiler/tests/self_adjoint_eig_op_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.self_adjoint_eig.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SelfAdjointEigOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def _test(self, dtype, shape): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + x_np = x_np + np.swapaxes(x_np, -1, -2) + n = shape[-1] + + e_np, _ = np.linalg.eigh(x_np) + with self.cached_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + e, v = linalg_ops.self_adjoint_eig(x_tf) + e_val, v_val = sess.run([e, v], feed_dict={x_tf: x_np}) + + v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) + self.assertAlmostEqual(np.mean(v_diff**2), 0.0, delta=1e-6) + self.assertAlmostEqual(np.mean((e_val - e_np)**2), 0.0, delta=1e-6) + + SIZES = [1, 2, 5, 10, 32] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testSelfAdjointEig(self, n, dtype): + for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): + self._test(dtype, batch_dims + (n, n)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index ee7ca7e6f196e114ff18e2597145e5c198980b08..df5914a518e06e4190c623a14287de8daefebd40 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -167,8 +167,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): beta = (b - mu) / sigma z = normal_cdf(beta) - normal_cdf(alpha) - self.assertTrue((y >= a).sum() == n) - self.assertTrue((y <= b).sum() == n) + self.assertEqual((y >= a).sum(), n) + self.assertEqual((y <= b).sum(), n) # For more information on these calculations, see: # Burkardt, John. "The Truncated Normal Distribution". diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 47e0f384a4f1e46ccc35584aaff3a0aceff8a985..a380715301b08ce2186c97b678b7235b9121d178 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -102,7 +102,7 @@ class ListOpsTest(xla_test.XLATestCase): _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Set the max number of elements"): - self.assertEqual(sess.run(e), 1.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15))) def testEmptyTensorListMax(self): with self.cached_session() as sess, self.test_scope(): @@ -136,6 +136,17 @@ class ListOpsTest(xla_test.XLATestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [3.0, 2.0]) + def testSetDoesNotUpdatePushIndex(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_shape=[], element_dtype=dtypes.float32, max_num_elements=2) + # SetItem should not change the push index. + l = list_ops.tensor_list_set_item(l, 1, 3.) + l = list_ops.tensor_list_push_back(l, 5.) + l = list_ops.tensor_list_push_back(l, 7.) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [5., 7.]) + def testGetSetReserved(self): with self.cached_session(), self.test_scope(): l = list_ops.tensor_list_reserve( @@ -146,6 +157,25 @@ class ListOpsTest(xla_test.XLATestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(t, [3.0, 0.0]) + def testSetStackReservedUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.tensor_list_reserve( + element_dtype=dtypes.float32, element_shape=None, num_elements=2) + l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) + self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]]) + + def testPushInEmptyListWithUnknownElementShape(self): + with self.cached_session(), self.test_scope(): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=None, max_num_elements=2) + l = list_ops.tensor_list_push_back(l, [3.0, 4.0]) + # Pushing an element with a different shape should raise an error. + with self.assertRaisesRegexp(errors.InvalidArgumentError, "Shape"): + l = list_ops.tensor_list_push_back(l, 5.) + self.evaluate( + list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)) + def testGetSetReservedNonScalar(self): with self.cached_session() as sess, self.test_scope(): l = list_ops.tensor_list_reserve( diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 083e2e58ae02b1aa383da76aebfca60fac59b84b..f2e0eac2d99fe3b71ecabd4b9977817c5f9c372c 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -72,6 +72,7 @@ class UnaryOpsTest(xla_test.XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) self.assertAllCloseAccordingToType( result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) else: @@ -260,7 +261,8 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], + dtype=dtype)).astype(dtype), rtol=1e-4, atol=1e-6) @@ -391,6 +393,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype), + expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.is_finite, np.array( @@ -705,7 +712,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), - expected=np.array([[2, 1]], dtype=dtype)) + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) self._assertOpOutputMatchesExpected( math_ops.negative, @@ -743,6 +750,10 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array( [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.sign, + np.array([[np.nan]], dtype=dtype), + expected=np.array([[0.0]], dtype=dtype)) def testLogicalOps(self): self._assertOpOutputMatchesExpected( @@ -811,6 +822,12 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1, 2, 0], np.int32), expected=np.array([2, 0, 1], dtype=np.int32)) + def testInvertPermutationTwiceIsNoop(self): + self._assertOpOutputMatchesExpected( + lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)), + np.array([1, 2, 0], np.int32), + expected=np.array([1, 2, 0], dtype=np.int32)) + def testRank(self): rank_op = lambda x: array_ops.rank_internal(x, optimize=False) for dtype in self.numeric_types: @@ -865,6 +882,17 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1], [1], [4]], dtype=dtype), expected=np.int32(3)) + def testSizeWithInt64OutType(self): + + def size_op(x): + return array_ops.size_internal(x, optimize=False, out_type=np.int64) + + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + size_op, + np.array([[-1], [1], [4]], dtype=dtype), + expected=np.int64(3)) + def testUnpack(self): self._assertOpOutputMatchesExpected( array_ops.unstack, @@ -974,7 +1002,7 @@ class UnaryOpsTest(xla_test.XLATestCase): def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) - expected = np.logaddexp(zero, features) + expected = np.logaddexp(zero, features).astype(dtype) self._assertOpOutputMatchesExpected( nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 00d3c8cc5f610ea5c308fa7df49d963c78919d63..63cad6a159c3a9b0da9e3bb86ff250dd29e45729 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -160,6 +160,7 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", "//tensorflow/python:resources", @@ -170,13 +171,11 @@ tf_cuda_library( name = "trt_resources", srcs = [ "utils/trt_int8_calibrator.cc", - "utils/trt_resource_manager.cc", "utils/trt_resources.cc", ], hdrs = [ "utils/trt_int8_calibrator.h", "utils/trt_lru_cache.h", - "utils/trt_resource_manager.h", "utils/trt_resources.h", ], deps = [ @@ -265,7 +264,6 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", @@ -351,6 +349,7 @@ cc_library( "segment/segment.h", "segment/union_find.h", ], + copts = tf_copts(), deps = [ "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", @@ -360,11 +359,12 @@ cc_library( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "segment_test", size = "small", srcs = ["segment/segment_test.cc"], tags = [ + "no_cuda_on_cpu_tap", "no_windows", "nomac", ], @@ -430,7 +430,7 @@ cc_library( copts = tf_copts(), deps = [ "//tensorflow/core:framework", - "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", ], ) @@ -439,7 +439,7 @@ cc_library( srcs = ["utils/test_utils.cc"], hdrs = ["utils/test_utils.h"], deps = [ - "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", "@com_googlesource_code_re2//:re2", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index f3db42509ecf1d5176c8f56ef13d2c76d038ee7a..1f3cae3fda0cd7be296882b7b17ea47554edace8 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/segment/segment.h" #include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" @@ -90,29 +89,40 @@ TrtCandidateSelector::TrtCandidateSelector( Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange - static const std::set candidate_ops = { + static const auto* candidate_ops = new std::set{ "Abs", + "Acos", + "Acosh", "Add", + "Asin", + "Asinh", + "Atan", + "Atanh", "AvgPool", "BatchMatMul", "BiasAdd", + "Ceil", "ConcatV2", "Const", "Conv2D", "Conv2DBackpropInput", + "Cos", + "Cosh", "DepthwiseConv2dNative", "Div", "Exp", "ExpandDims", + "Floor", "FusedBatchNorm", "FusedBatchNormV2", + "GatherV2", "Identity", "LeakyRelu", "Log", "MatMul", "Max", - "MaxPool", "Maximum", + "MaxPool", "Mean", "Min", "Minimum", @@ -126,8 +136,10 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "Relu6", "Reshape", "Rsqrt", - "Rsqrt", "Sigmoid", + "Sin", + "Sinh", + "Slice", "Snapshot", "Softmax", "Sqrt", @@ -136,14 +148,15 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { "StridedSlice", "Sub", "Sum", + "Tan", "Tanh", "TopKV2", "Transpose", }; bool is_supported_op_type = - (candidate_ops.count(node->type_string()) || + (candidate_ops->count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); - static const std::set quantize_ops = { + static const auto* quantize_ops = new std::set{ "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", "FakeQuantWithMinMaxVars", @@ -153,7 +166,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // these ops to the relevant tensors. This happens regardless of the value of // use_calibration. if (precision_mode_ == TrtPrecisionMode::INT8 && - quantize_ops.count(node->type_string())) { + quantize_ops->count(node->type_string())) { is_supported_op_type = true; } // LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc) @@ -190,55 +203,6 @@ tensorflow::Status BuildNodeMap( } // namespace -// Function to get calibration from ResourceMgr and put them into nodedef. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, - bool is_dyn_op) { - LOG(INFO) << "Starting Calib Conversion"; - infer_graph->CopyFrom(graph_def); - auto trt_rm = TRTResourceManager::instance(); - auto calib_rm = trt_rm->getManager("TRTCalibration"); - int num_nodes = infer_graph->node_size(); - if (!is_dyn_op) { - LOG(WARNING) << "Construction of static int8 engine is not implemented " - "yet!. Dynamic engine will be constructed"; - } - for (int i = 0; i < num_nodes; ++i) { - auto n = infer_graph->mutable_node(i); - if (n->op() == "TRTEngineOp") { - VLOG(1) << "Processing " << n->name(); - const string& container_name = n->attr().at("segment_funcdef_name").s(); - TRTCalibrationResource* cres = nullptr; - auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); - if (!status.ok()) { - LOG(ERROR) << "Could not get Calibration information. Did you run with " - "calibration data?"; - return tensorflow::errors::FailedPrecondition( - "Need to run graph with calibration data first!"); - } - tensorflow::core::ScopedUnref calib_sc(cres); - if (cres->calibrator_) { - cres->calibrator_->waitAndSetDone(); - cres->thr_->join(); - const auto& calibration_table = - cres->calibrator_->getCalibrationTableAsString(); - if (!calibration_table.size()) { - LOG(ERROR) << "Calibration table is empty"; - return tensorflow::errors::Unknown( - "Calibration table is missing. This shouldn't have happened!"); - } - n->mutable_attr()->at("calibration_data").set_s(calibration_table); - } else { - LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; - return tensorflow::errors::Unknown( - "Can't get TRTCalibrator from resource manager!"); - } - TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); - } - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, @@ -662,8 +626,8 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, info.use_calibration, /*convert_successfully=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); - segment_string = - string((const char*)engine_data->data(), engine_data->size()); + segment_string = string(static_cast(engine_data->data()), + engine_data->size()); if (calibrate_int8) { // See above comment about why not putting this inside the 'else' branch. segment_string = info.segment_graph_def.SerializeAsString(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 95cf0227dcf84396b9de52194ae3a750f4acca66..80f68d36a3ab894e97586687ee9ab93dddc73c50 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -85,12 +85,6 @@ struct ConversionParams { std::vector cached_engine_batches; // list of cached engines }; -// This method extracts calibration information from the resource managers -// and puts them in to engine nodedefs. -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, - bool is_dyn_op); - // - max_batch_size: maximum batch size which can be used for inference for // optimization targets inference run with max batch size. // - max_workspace_size_bytes: The upper bound of memory allowance for engine diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index cabc6ccfa13df77b3bd26f51f35284816141423a..1a754181debf41865190aa7f9ca6a76efea98181 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -75,7 +75,7 @@ TEST(TrtCandidateSelector, Basics) { feed, const_1, matmul_attrs); // Unsupported op. - auto unsupported_op = ops::Sin(s.WithOpName("sin"), feed); + auto unsupported_op = ops::Erf(s.WithOpName("sin"), feed); // Incompatible input. auto incompatible_feed = ops::Placeholder(s.WithOpName("feed"), DT_DOUBLE); @@ -108,7 +108,7 @@ TEST(TrtCandidateSelector, Basics) { "transpose_a is not supported for TensorRT FullyConnected " "(op: MatMul), at: incompatible_matmul"); ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), - error::UNIMPLEMENTED, "Op type Sin is not supported"); + error::UNIMPLEMENTED, "Op type Erf is not supported"); ExpectStatus( selector.IsTensorRTCandidate( matmul_with_incompatible_input.operation.node()), diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 79b1cba32909c119a9127c3d254a6f14a16cb660..9a2ac8c3e5f1d149baf5de25c940e24a8acc9125 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" @@ -94,7 +93,6 @@ bool IsEngineOutput(absl::string_view name) { namespace convert { using absl::StrAppend; using absl::StrCat; -using ::tensorflow::str_util::Split; inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, nvinfer1::DataType* trt_dtype) { @@ -194,6 +192,15 @@ Status ValidateTensorProperties(const string& producer_node_type, *trt_dims = TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true); *batch_size = shape.dim_size(0); + // Don't convert empty tensors (dim value of 0). + for (int d = 1; d < shape.dims(); ++d) { + if (shape.dim_size(d) == 0) { + return errors::Unimplemented( + "Input tensor with shape ", shape.DebugString(), + " is an empty tensor, which is not supported by TRT"); + } + } + if (validation_only) return Status::OK(); // Following are validations at runtime. @@ -297,8 +304,8 @@ Status Converter::GetTrtBroadcastShape( const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; auto compute_output_dims = - [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims, - int* output_dims_array, nvinfer1::Dims* output_dims) { + [](const TRT_TensorOrWeights& input, int broadcast_num_dims, + int* output_dims_array, nvinfer1::Dims* output_dims) { const nvinfer1::Dims input_dims = input.GetTrtDims(); std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); std::copy(input_dims.d, input_dims.d + input_dims.nbDims, @@ -380,6 +387,32 @@ tensorflow::Status CreateBroadcastableScalarConstant( return Status::OK(); } +// Convert an axis from TF format to TRT format while validating. TF format +// includes the batch dimension, while TRT does not. TF can also use negative +// indices. +// TODO(tmorris): Use this method in more ops. +tensorflow::Status ConvertAxis(int tf_axis, int trt_nb_dims, + absl::string_view node_name, int* trt_axis) { + const int tf_nb_dims = trt_nb_dims + 1; + // Check bounds. + if (tf_axis < -tf_nb_dims || tf_axis >= tf_nb_dims) { + return tensorflow::errors::InvalidArgument( + "Axis value of ", tf_axis, " is out of bounds, must be in range [", + -tf_nb_dims, ", ", tf_nb_dims, "), at ", node_name); + } + // Make negative axis positive. + if (tf_axis < 0) tf_axis += tf_nb_dims; + // Don't allow axis to be the batch dimension. + if (tf_axis == 0) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow manipulation of the batch dimension, at ", + node_name); + } + // Remove batch dimension. + *trt_axis = tf_axis - 1; + return Status::OK(); +} + inline bool DimsEqual(const nvinfer1::Dims& dim_l, const nvinfer1::Dims& dim_r) { if (dim_l.nbDims != dim_r.nbDims) { @@ -393,6 +426,15 @@ inline bool DimsEqual(const nvinfer1::Dims& dim_l, return true; } +bool AllLengthsEqual(const std::vector>& inputs) { + if (inputs.size() == 0) return true; + int length = inputs.at(0).size(); + for (int i = 1; i < inputs.size(); i++) { + if (inputs.at(i).size() != length) return false; + } + return true; +} + inline nvinfer1::Dims GetTrtDimsForTensor(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -530,6 +572,16 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return 0; } #endif +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: nvinfer1::DataType trt_dtype_; nvinfer1::Dims trt_dims_; @@ -921,7 +973,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) { for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights& output = outputs[i]; string output_name = node_def.name(); - if (i != 0) output_name = StrCat(output_name, ":", i); + if (i != 0) absl::StrAppend(&output_name, ":", i); // We need to check the name before setting it. If the input is one of the // engine input, setting the name here will overwrite engine input // bindings which will cause runtime error. @@ -2107,7 +2159,7 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { // Mark axes to remove by setting them to 0. TFAttrs attrs(node_def); auto squeeze_dims = attrs.get>("squeeze_dims"); - if (squeeze_dims.size() == 0) { + if (squeeze_dims.empty()) { return tensorflow::errors::Unimplemented( "Squeeze is only implemented for explicit dims, at ", node_def.name()); } @@ -2152,100 +2204,73 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { return tensorflow::Status::OK(); } -// Gets the bounds (start or end) from the weights of a StridedSlice op. -tensorflow::Status GetStridedSliceBound(const std::vector& input_dims, - const TRT_ShapedWeights& bound_weights, - int mask, bool begin, string node_name, - std::vector* output_bound) { - const string bound_name = (begin) ? "begin" : "end"; - const int* weights_ptr = static_cast(bound_weights.GetValues()); - *output_bound = - std::vector(weights_ptr, weights_ptr + bound_weights.count()); - if (output_bound->size() != input_dims.size()) { - return tensorflow::errors::InvalidArgument( - "StridedSlice \"", bound_name, "\" specified ", - std::to_string(output_bound->size()), " dimensions, but input rank is ", - std::to_string(input_dims.size()), ", at ", node_name); - } - for (int i = 0; i < output_bound->size(); i++) { - if ((1 << i) & mask) { - // Apply mask. - (*output_bound)[i] = (begin) ? 0 : input_dims[i]; - // Masked bound will always result in a valid, non-negative bound, so we - // don't need the following checks. For the common case of using masks on - // a undefined batch dim (-1), we specifically don't want to do the - // following checks because they will erroneously detect an out of range - // bound or try to correct the negative value. - continue; - } - // Make sure bound is valid. - if (((*output_bound)[i] < -input_dims[i]) || - ((*output_bound)[i] > input_dims[i])) { +tensorflow::Status ConvertStridedSliceHelper(OpConverterParams* params, + const TRT_TensorOrWeights& input, + std::vector begin, + std::vector size, + const std::vector& stride) { + const auto& node_def = params->node_def; + // Get input dims. + nvinfer1::Dims dims = input.GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), -1); + // Check bounds. + for (int i = 1; i < input_dims.size(); i++) { + if (begin[i] < 0 || begin[i] > input_dims[i]) { + return tensorflow::errors::InvalidArgument( + "\"begin\" for dimension ", std::to_string(i), " in ", node_def.op(), + " is out of range, at ", node_def.name()); + } + const int end = begin[i] + size[i]; + if (end < 0 || end > input_dims[i]) { return tensorflow::errors::InvalidArgument( - bound_name, " value of ", std::to_string((*output_bound)[i]), - " for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at ", - node_name); + "\"begin\" + \"size\" for dimension ", std::to_string(i), " in ", + node_def.op(), " is out of range, at ", node_def.name()); } - // Convert negative values to their positive equivalent. - if ((*output_bound)[i] < 0) { - (*output_bound)[i] += input_dims[i]; + if (size[i] <= 0) { + return tensorflow::errors::InvalidArgument( + "\"size\" cannot be negative or zero for ", node_def.op(), ", at ", + node_def.name()); } } - return tensorflow::Status::OK(); -} +// TRT 5.1 adds a slice layer. For older versions, we attempt to use the +// padding layer with negative padding. +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + // Use ISliceLayer. + nvinfer1::Dims begin_dims, size_dims, stride_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(size, &size_dims, + /*ignore_first_dim=*/true)); + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(stride, &stride_dims, + /*ignore_first_dim=*/true)); + if (params->validation_only) return Status::OK(); -tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { - const auto& inputs = params->inputs; - const auto& node_def = params->node_def; - TF_RETURN_IF_ERROR(CheckInputsWeights( - *params, - {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); - // Get input dims. - nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); - std::vector input_dims(dims.d, dims.d + dims.nbDims); - if (inputs.at(0).is_tensor()) { - // Temporarily add batch dimension so that indexes line up properly. - input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); - } - if (input_dims.size() > 4) { - return tensorflow::errors::Unimplemented( - "StridedSlice is not implemented for tensors with rank > 4, at ", - node_def.name()); - } - TFAttrs attrs(node_def); - // Get begin and end bounds per axis. - std::vector begin, end; - TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), - attrs.get("begin_mask"), true, - node_def.name(), &begin)); - TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), - attrs.get("end_mask"), false, - node_def.name(), &end)); - // Get strides per axis (must all be 1). - TRT_ShapedWeights stride_weights = inputs.at(3).weights(); - const int* stride_weights_ptr = static_cast(stride_weights.GetValues()); - std::vector strides(stride_weights_ptr, - stride_weights_ptr + stride_weights.count()); - for (int x : strides) { + nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( + *const_cast(input.tensor()), begin_dims, size_dims, + stride_dims); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); +#else + // Use IPaddingLayer. + // Strides must be 1 in this case. + for (int x : stride) { if (x != 1) { return tensorflow::errors::Unimplemented( - "StridedSlice is only implemented for stride of 1, at ", + "Strides other than 1 are not supported with this version of TRT, " + "at ", node_def.name()); } } - // Unsupported mask options. - for (const string& attr : - {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { - int attr_val = attrs.get(attr); - if (attr_val != 0) { - return tensorflow::errors::Unimplemented( - attr, " is not supported for StridedSlice, at ", node_def.name()); - } + // Rank must be 2, 3 or 4. + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented(node_def.op(), + " for tensors with rank > 4 is " + "not supported in this version of " + "TRT, at ", + node_def.name()); } - - nvinfer1::ITensor* tensor = - const_cast(inputs.at(0).tensor()); // Reshape if necessary to 4-D, since IPaddingLayer requires a 4-D input. const bool need_reshape = (input_dims.size() != 4); int reshape_dims_added = 0; @@ -2255,7 +2280,7 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { while (input_dims.size() < 4) { input_dims.insert(input_dims.begin() + 1, 1); begin.insert(begin.begin() + 1, 0); - end.insert(end.begin() + 1, 1); + size.insert(size.begin() + 1, 1); reshape_dims_added++; } TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &reshape_dims, @@ -2263,23 +2288,22 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { } // Find dimensions which need to be sliced. std::vector pad_dims; - for (int i = 0; i < input_dims.size(); i++) { - if ((begin[i] != 0) || (end[i] != input_dims[i])) { - if (i == 0) { - return tensorflow::errors::Unimplemented( - "StridedSlice can't modify batch dim, at ", node_def.name()); - } else if ((end[i] - begin[i]) < 0) { - return tensorflow::errors::InvalidArgument( - "New size of sliced dimension is negative, at ", node_def.name()); - } + for (int i = 1; i < input_dims.size(); i++) { + if ((begin[i] != 0) || (begin[i] + size[i] != input_dims[i])) { pad_dims.push_back(i); } } - if (pad_dims.size() == 0) { - // No dimensions are changed. We could create a padding layer anyway with - // values of 0. + if (pad_dims.empty()) { + // No dimensions are changed, so this is a no-op. We could just return the + // input without creating a new layer. TRT will crash if an empty engine + // with no layers is attempted to be created, so we add a no-op shuffle to + // prevent our unit tests from breaking. + // TODO(tmorris): Allow empty engines in the unit tests and return the input + // as output here. if (params->validation_only) return Status::OK(); - params->outputs->push_back(inputs.at(0)); + nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle( + *const_cast(input.tensor())); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return tensorflow::Status::OK(); } else if (pad_dims.size() == 1) { // Only one dim is modified but we have to have 2, mark a second dim which @@ -2292,16 +2316,19 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { } } else if (pad_dims.size() > 2) { return tensorflow::errors::Unimplemented( - "StridedSlice can only modify 2 dimensions, at ", node_def.name()); + node_def.op(), + " can only modify up to 2 dimensions in this version of TRT, at ", + node_def.name()); } std::sort(pad_dims.begin(), pad_dims.end()); // Convert to pre/post padding values. Since TRT does not have a StridedSlice - // or Slice layer, we instead create an IPaddingLayer with negative padding. + // or Slice layer prior to 5.1, we instead create an IPaddingLayer with + // negative padding. nvinfer1::DimsHW pre_padding, post_padding; for (int i = 0; i < pad_dims.size(); i++) { const int axis = pad_dims[i]; pre_padding.d[i] = -begin[axis]; - post_padding.d[i] = end[axis] - input_dims[axis]; + post_padding.d[i] = (begin[axis] + size[axis]) - input_dims[axis]; } // IPaddingLayer will always apply the padding to dims 2,3 (input format is @@ -2321,10 +2348,11 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { if (params->validation_only) return Status::OK(); // Start conversion. + nvinfer1::ITensor* tensor = const_cast(input.tensor()); if (need_reshape) { const nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), reshape_dims, &output_tensor)); + input, reshape_dims, &output_tensor)); tensor = const_cast(output_tensor); } if (need_transpose) { @@ -2333,7 +2361,6 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { tensor, transpose_order, &output_tensor)); tensor = const_cast(output_tensor); } - // Add padding layer nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *const_cast(tensor), pre_padding, post_padding); @@ -2341,7 +2368,6 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { params->converter->MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); - // Restore transpose if (need_transpose) { const nvinfer1::ITensor* output_tensor = nullptr; @@ -2354,7 +2380,7 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { // Calculate output dimensions for (int i = 0; i < pad_dims.size(); i++) { const int axis = pad_dims[i]; - input_dims[axis] = end[axis] - begin[axis]; + input_dims[axis] = size[axis]; } // Remove added 1 dimensions for (int i = 0; i < reshape_dims_added; i++) { @@ -2378,6 +2404,135 @@ tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { params->outputs->push_back( TRT_TensorOrWeights(const_cast(tensor))); return tensorflow::Status::OK(); +#endif +} + +tensorflow::Status ConvertSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"input", false}, {"begin", true}, {"size", true}})); + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector size = inputs.at(2).weights().ToVector(); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + if (!AllLengthsEqual({input_dims, begin, size})) { + return tensorflow::errors::InvalidArgument( + "Length of begin and size arguments must equal rank of input for " + "Slice, at ", + node_def.name()); + } + // Check that batch dimension is unmodified. + const bool begin_is_modified = begin[0] != 0; + // If size[0]s is not -1, we can only know if the batch dimension is + // unmodified when the batch size is defined. When the batch size is + // undefined, we don't convert to be safe. + const bool batch_size_is_defined = input_dims[0] > 0; + const bool size_is_modified = + size[0] != -1 && (!batch_size_is_defined || + (batch_size_is_defined && size[0] != input_dims[0])); + if (begin_is_modified || size_is_modified) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Size of -1 signifies to take all remaining elements. + for (int i = 1; i < input_dims.size(); i++) { + if (size[i] == -1) { + size[i] = input_dims[i] - begin[i]; + } + } + // Stride is 1 for all dims. + std::vector stride(begin.size(), 1); + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}})); + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + // Add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + // Get begin and end bounds per axis. + std::vector begin = inputs.at(1).weights().ToVector(); + std::vector end = inputs.at(2).weights().ToVector(); + std::vector stride = inputs.at(3).weights().ToVector(); + if (!AllLengthsEqual({input_dims, begin, end, stride})) { + return tensorflow::errors::InvalidArgument( + "Length of begin, end, and stride arguments must equal rank of input " + "for StridedSlice, at ", + node_def.name()); + } + // Unsupported mask options. + TFAttrs attrs(node_def); + for (const string& attr : + {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int attr_val = attrs.get(attr); + if (attr_val != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not supported for StridedSlice, at ", node_def.name()); + } + } + const int begin_mask = attrs.get("begin_mask"); + const int end_mask = attrs.get("end_mask"); + // Check that batch dimension is unmodified. + const bool begin_is_modified = !(begin_mask & 1) && begin[0] != 0; + const bool stride_is_modified = stride[0] != 1; + // If the batch size is -1 and the end mask is not set, we can only know if + // the batch dimension is unmodified when the batch size is defined. When the + // batch size is undefined, we don't convert to be safe. + const bool batch_size_is_defined = input_dims[0] > 0; + const bool end_is_modified = + !(end_mask & 1) && (!batch_size_is_defined || + (batch_size_is_defined && end[0] != input_dims[0])); + if (begin_is_modified || stride_is_modified || end_is_modified) { + return tensorflow::errors::Unimplemented( + "TensorRT does not allow modifications to the batch dimension, at ", + node_def.name()); + } + // Standarize begin and end bounds by applying masks, making negative values + // positive, and correcting out of bounds ranges (StridedSlice does this + // silently). + for (int i = 1; i < input_dims.size(); i++) { + // Begin + if ((1 << i) & begin_mask) { + begin[i] = 0; + } else if (begin[i] < 0) { + begin[i] += input_dims[i]; + } + begin[i] = std::max(0, std::min(begin[i], input_dims[i])); + // End + if ((1 << i) & end_mask) { + end[i] = input_dims[i]; + } else if (end[i] < 0) { + end[i] += input_dims[i]; + } + end[i] = std::max(0, std::min(end[i], input_dims[i])); + } + // Negative or zero strides currently not supported. + for (int i = 0; i < input_dims.size(); i++) { + if (stride[i] <= 0) { + return tensorflow::errors::Unimplemented( + "Negative or zero stride values are not supported for StridedSlice, " + "at ", + node_def.name()); + } + } + // TRT Slice layer uses (begin, size) instead of (begin, end) + std::vector size(input_dims.size()); + for (int i = 0; i < input_dims.size(); i++) { + // Divide by stride (round up) + size[i] = (end[i] - begin[i] + stride[i] - 1) / stride[i]; + } + return ConvertStridedSliceHelper(params, inputs.at(0), begin, size, stride); } tensorflow::Status ConvertConv2D(OpConverterParams* params) { @@ -2947,58 +3102,104 @@ Status ConvertBinary(OpConverterParams* params) { return status; } -tensorflow::Status ConvertUnary(OpConverterParams* params) { +tensorflow::Status ConvertRsqrt(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; - static const std::unordered_map ops{ - {"Neg", nvinfer1::UnaryOperation::kNEG}, - {"Exp", nvinfer1::UnaryOperation::kEXP}, - {"Log", nvinfer1::UnaryOperation::kLOG}, - {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, - {"Abs", nvinfer1::UnaryOperation::kABS}, - {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, - }; TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + if (params->validation_only) return tensorflow::Status::OK(); - // TODO(jie): check type - const nvinfer1::ITensor* tensor = nullptr; - TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), inputs.at(0).GetTrtDims(), &tensor)); + // TODO(tmorris): params->converter is null during validation. Allow + // precision_mode and use_calibration to be accessed during validation and + // include this check in validation. + // We will need a quantization range for intermediate tensor if not using + // calibration. + // + // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) + // ^ + // need range here + if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && + !params->converter->use_calibration()) { + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration for Rsqrt, consider replacing with " + "Sqrt -> FakeQuant -> Reciprocal ops, at ", + node_def.name()); + } + // Start conversion. + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + // Sqrt + nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary( + *const_cast(tensor), nvinfer1::UnaryOperation::kSQRT); + TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name()); + // Recip + nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary( + *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0))); + return tensorflow::Status::OK(); +} - nvinfer1::IUnaryLayer* layer; - if (node_def.op() == "Rsqrt") { - // We will need a quantization range for intermediate tensor if not using - // calibration. - // - // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) - // ^ - // need range here - if (params->converter->precision_mode() == TrtPrecisionMode::INT8 && - !params->converter->use_calibration()) { - return errors::Unimplemented( - "Intermediate quantization range cannot be determined without" - " calibration for Rsqrt, consider replacing with " - "Sqrt -> FakeQuant -> Reciprocal ops, at ", - node_def.name()); - } - layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kSQRT); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - tensor = layer->getOutput(0); - layer = params->converter->network()->addUnary( - *const_cast(tensor), - nvinfer1::UnaryOperation::kRECIP); - } else if (ops.count(node_def.op()) != 0) { - layer = params->converter->network()->addUnary( - *const_cast(tensor), ops.at(node_def.op())); - } else { - return tensorflow::errors::InvalidArgument( - "Binary op: ", node_def.op(), " not supported, at ", node_def.name()); +const std::unordered_map* +UnaryOperationMap() { + static auto* const m = + new std::unordered_map({ + {"Neg", nvinfer1::UnaryOperation::kNEG}, + {"Exp", nvinfer1::UnaryOperation::kEXP}, + {"Log", nvinfer1::UnaryOperation::kLOG}, + {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"Abs", nvinfer1::UnaryOperation::kABS}, + {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + {"Sin", nvinfer1::UnaryOperation::kSIN}, + {"Cos", nvinfer1::UnaryOperation::kCOS}, + {"Tan", nvinfer1::UnaryOperation::kTAN}, + {"Sinh", nvinfer1::UnaryOperation::kSINH}, + {"Cosh", nvinfer1::UnaryOperation::kCOSH}, + {"Asin", nvinfer1::UnaryOperation::kASIN}, + {"Acos", nvinfer1::UnaryOperation::kACOS}, + {"Atan", nvinfer1::UnaryOperation::kATAN}, + {"Asinh", nvinfer1::UnaryOperation::kASINH}, + {"Acosh", nvinfer1::UnaryOperation::kACOSH}, + {"Atanh", nvinfer1::UnaryOperation::kATANH}, + {"Ceil", nvinfer1::UnaryOperation::kCEIL}, + {"Floor", nvinfer1::UnaryOperation::kFLOOR}, +#endif + }); + return m; +} + +tensorflow::Status ConvertUnary(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + auto op_pair = UnaryOperationMap()->find(node_def.op()); + if (op_pair == UnaryOperationMap()->end()) { + return tensorflow::errors::Unimplemented( + "Unary op: ", node_def.op(), " not supported at: ", node_def.name()); } + if (params->validation_only) return tensorflow::Status::OK(); + // Start conversion. + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary( + *const_cast(tensor), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + + // Set quantization ranges. + if (node_def.op() == "Sin" || node_def.op() == "Cos") { + params->converter->ProvideQuantizationRange(output_tensor, -1.0f, 1.0f); + } else if (node_def.op() == "Asin" || node_def.op() == "Atan") { + params->converter->ProvideQuantizationRange(output_tensor, -M_PI_2, M_PI_2); + } else if (node_def.op() == "Acos") { + params->converter->ProvideQuantizationRange(output_tensor, 0.0f, M_PI); + } else if (node_def.op() == "Neg" || node_def.op() == "Abs") { + // Neg and Abs will have same range as input since TRT uses symmetric + // quantization. + // TODO(tmorris): Should we infer ranges for Ceil and Floor as well? + params->converter->MarkQuantizationRangesAsInferrable( + const_cast(tensor), output_tensor); + } params->outputs->push_back( TRT_TensorOrWeights(const_cast(output_tensor))); return tensorflow::Status::OK(); @@ -3139,7 +3340,7 @@ tensorflow::Status ConvertPad(OpConverterParams* params) { } // No padding at all, we should exit - if (pad_index.size() == 0) { + if (pad_index.empty()) { params->outputs->push_back(inputs.at(0)); return tensorflow::Status::OK(); } @@ -3329,7 +3530,7 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { TRT_ShapedWeights dummy_power_weights(parameter_type); size_t nweight = 0; for (int i = 1; i < 5; i++) { - nweight = std::max(nweight, (size_t)inputs.at(i).weights().count()); + nweight = std::max(nweight, inputs.at(i).weights().count()); } TRT_ShapedWeights* ptr_shape_weights = nullptr; for (int i = 1; i < 5; i++) { @@ -3414,6 +3615,29 @@ tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status ConvertGather(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, {{"params", false}, {"indices", false}, {"axis", true}})); + absl::Span axis = inputs.at(2).weights().GetSpan(); + if (axis.size() != 1) { + return tensorflow::errors::InvalidArgument( + "Axis for GatherV2 must be a scalar, at ", node_def.name()); + } + int trt_axis = 0; + TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims, + node_def.name(), &trt_axis)); + if (params->validation_only) return Status::OK(); + + nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( + *const_cast(inputs.at(0).tensor()), + *const_cast(inputs.at(1).tensor()), trt_axis); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + tensorflow::Status ConvertMatMulHelper(OpConverterParams* params, TRT_TensorOrWeights tensor_input, TRT_ShapedWeights weights_raw, @@ -3644,11 +3868,14 @@ static void RegisterValidatableOpConverters( (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput; (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; (*registration)["ExpandDims"] = ConvertExpandDims; + (*registration)["GatherV2"] = ConvertGather; (*registration)["LeakyRelu"] = ConvertLeakyRelu; (*registration)["MatMul"] = ConvertMatMul; (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; (*registration)["Reshape"] = ConvertReshape; + (*registration)["Rsqrt"] = ConvertRsqrt; + (*registration)["Slice"] = ConvertSlice; (*registration)["Square"] = ConvertSquare; (*registration)["Squeeze"] = ConvertSqueeze; (*registration)["StridedSlice"] = ConvertStridedSlice; @@ -3673,6 +3900,9 @@ static void RegisterValidatableOpConverters( for (auto normalization_op_type : {"FusedBatchNorm", "FusedBatchNormV2"}) { (*registration)[normalization_op_type] = ConvertFusedBatchNorm; } + for (auto unary_op_pair : *UnaryOperationMap()) { + (*registration)[unary_op_pair.first] = ConvertUnary; + } } void TrtNodeValidator::RegisterOpValidators() { @@ -3685,14 +3915,6 @@ void Converter::RegisterOpConverters() { op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - op_registry_["Rsqrt"] = ConvertUnary; - op_registry_["Reciprocal"] = ConvertUnary; - op_registry_["Exp"] = ConvertUnary; - op_registry_["Log"] = ConvertUnary; - op_registry_["Sqrt"] = ConvertUnary; - op_registry_["Abs"] = ConvertUnary; - op_registry_["Neg"] = ConvertUnary; - op_registry_["Sum"] = ConvertReduce; op_registry_["Prod"] = ConvertReduce; op_registry_["Max"] = ConvertReduce; @@ -3722,8 +3944,12 @@ tensorflow::Status ConvertGraphDefToEngine( builder->setMaxWorkspaceSize(max_workspace_size_bytes); builder->setGpuAllocator(allocator); if (precision_mode == TrtPrecisionMode::FP16) { - builder->setHalf2Mode(true); + builder->setFp16Mode(true); } else if (precision_mode == TrtPrecisionMode::INT8) { + // Setting FP16 mode as well allows TRT to also consider FP16 kernels and + // use them in situations where they are faster than INT8 or where INT8 is + // not supported for a given layer. + builder->setFp16Mode(true); builder->setInt8Mode(true); if (use_calibration) { builder->setInt8Calibrator(calibrator); @@ -3899,7 +4125,7 @@ tensorflow::Status ConvertSegmentToGraphDef( local_scope = GetCommonNameScope(local_scope, node->name()); old_to_new_id_map[node->id()] = segment_def->node_size(); auto snode = segment_def->add_node(); - snode->CopyFrom(node->def()); + *snode = node->def(); VLOG(2) << "Copying " << snode->name() << " to subgraph"; } // Update the inputs of the new input nodes to point to placeholder nodes. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index d1e30eb848bd6ab62719ca6da561d14b05d8537d..7b37173090519ff6fadd956942d7ea12a0644981 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -190,6 +190,17 @@ class TRT_ShapedWeights { string DebugString() const; + template + absl::Span GetSpan() const { + return absl::Span(tensor_.flat().data(), count()); + } + + template + std::vector ToVector() const { + auto span = GetSpan(); + return std::vector(span.data(), span.data() + span.size()); + } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; // Note: shape.type[] is not used. tensorflow::DataType type_; @@ -560,6 +571,9 @@ class Converter { friend class OpConverterTest; }; +// Map of all supported UnaryOperations +const std::unordered_map* UnaryOperationMap(); + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 77221f6d9a42a165e8f9e322e1f876b02f4db59f..0aa48913f463b55a4252b01281c3ed3feed35539 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -60,6 +60,7 @@ namespace convert { using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::NanSensitiveFloatNear; // TODO(laigd): put this into some test utils file. void ExpectStatus(Status status, error::Code code = error::OK, @@ -241,6 +242,16 @@ class FakeITensor : public nvinfer1::ITensor { float getDynamicRange() const override { return dynamic_range_; } #endif +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + bool dynamicRangeIsSet() const override { return true; } + + void resetDynamicRange() override {} + + float getDynamicRangeMin() const override { return 0.f; } + + float getDynamicRangeMax() const override { return 0.f; } +#endif + private: string name_; nvinfer1::Dims dims_; @@ -1439,11 +1450,6 @@ TEST_F(OpConverterTest, ConvertReshape) { } struct TestParams { - TestParams(int input_batch_size, const std::vector& input_tensor_dims, - const std::vector& input_shape) - : batch_size(input_batch_size), - tensor_dims(input_tensor_dims), - shape(input_shape) {} int batch_size; std::vector tensor_dims; std::vector shape; @@ -2381,11 +2387,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) { } struct TestParams { - TestParams(const std::vector& input_dims, int axis, - const std::vector& expected_output_dims) - : input_dims(input_dims), - axis(axis), - expected_output_dims(expected_output_dims) {} std::vector input_dims; int axis; std::vector expected_output_dims; @@ -2498,11 +2499,6 @@ TEST_F(OpConverterTest, ConvertSqueeze) { } struct TestParams { - TestParams(const std::vector& input_dims, const std::vector& axis, - const std::vector& expected_output_dims) - : input_dims(input_dims), - axis(axis), - expected_output_dims(expected_output_dims) {} std::vector input_dims; std::vector axis; std::vector expected_output_dims; @@ -2621,46 +2617,62 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("strides", {4}, {1, 1, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "StridedSlice can't modify batch dim, at my_strided_slice"); + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); } { - // Stride is not 1, should fail. + // Dynamic batch size without end_mask, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); AddTestWeights("begin", {4}, {0, 0, 0, 0}); AddTestWeights("end", {4}, {1, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 2, -1, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "StridedSlice is only implemented for stride of " - "1, at my_strided_slice"); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_strided_slice"); } { - // Begin out of bounds, should fail. + // Dynamic batch size but using end_mask, ok. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*end_mask=*/1); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 2}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion(node_def); + } +// TRT 5.1+ supports strides +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + { + // Negative strides, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestTensor("input", {1, 2, 3}); - AddTestWeights("begin", {4}, {1, 2, 3, 4}); - AddTestWeights("end", {4}, {0, 1, 2, 3}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "begin value of 2 for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at my_strided_slice"); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, -1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Negative or zero stride values are not " + "supported for StridedSlice, at " + "my_strided_slice"); } +#else { - // End out of bounds, should fail. + // Stride is not 1, should fail. Reset(); NodeDef node_def = get_strided_slice_nodedef(); AddTestTensor("input", {1, 2, 3}); AddTestWeights("begin", {4}, {0, 0, 0, 0}); - AddTestWeights("end", {4}, {1, 2, 3, 4}); - AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "end value of 2 for StridedSlice is invalid, must be in the range " - "[-dim_size(i), dim_size(i)], at my_strided_slice"); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, 1, 3}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Strides other than 1 are not supported with " + "this version of TRT, at my_strided_slice"); } +#endif { // Size of sliced dim is negative, should fail. Reset(); @@ -2669,126 +2681,183 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { AddTestWeights("begin", {4}, {0, 0, 2, 0}); AddTestWeights("end", {4}, {1, 1, 0, 3}); AddTestWeights("strides", {4}, {1, 1, 1, 1}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "New size of sliced dimension is negative, at my_strided_slice"); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "\"size\" cannot be negative or zero for " + "StridedSlice, at my_strided_slice"); } struct TestParams { - TestParams(const std::vector& input_dims, - const std::vector& expected_output_dims, - const std::vector& begin, const std::vector& end, - const std::vector& begin_mask, - const std::vector& end_mask, - const std::vector& expected_output) - : input_dims(input_dims), - expected_output_dims(expected_output_dims), - begin(begin), - end(end), - expected_output(expected_output) { - // Masks are provided in terms of vectors for readability. Convert them to - // binary here. - this->begin_mask = 0; - for (int i = 0; i < begin_mask.size(); i++) { - if (begin_mask[i]) this->begin_mask |= (1 << i); - } - this->end_mask = 0; - for (int i = 0; i < end_mask.size(); i++) { - if (end_mask[i]) this->end_mask |= (1 << i); - } - } - std::vector input_dims; - std::vector expected_output_dims; std::vector begin; std::vector end; + std::vector strides; int begin_mask; int end_mask; - std::vector expected_output; + std::vector expected_output_dims; + std::vector expected_output; + }; + + auto get_mask = [](const std::vector& mask) { + int result = 0; + for (int i = 0; i < mask.size(); i++) { + if (mask[i]) result += (1 << i); + } + return result; }; + // Same input is used for all tests. + const std::vector ok_input = {1, 2, 3, 4, 5, 6}; + +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + const int kStridedSliceOKCases = 23; +#else + const int kStridedSliceOKCases = 19; +#endif // Ok. - const int kStridedSliceOKCases = 18; TestParams ok_params[kStridedSliceOKCases] = { - // 2D Crop. - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, - /*expected_output=*/{5, 6}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, - /*expected_output=*/{5, 6}}, - // 2D Crop, with transpose. - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, - /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{5, 6}}, - TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, - /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, - /*expected_output=*/{5, 6}}, - // 2D Crop, with reshape. - TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, - /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, - /*expected_output=*/{1, 2}}, - TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, - /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, - /*expected_output=*/{5, 6}}, - // 1D Crop. - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, - /*expected_output=*/{1, 2, 4, 5}}, - TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, - /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, - /*expected_output=*/{4, 5, 6}}, - // 1D Crop, with transpose. - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, - /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, - /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, - /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, - /*expected_output=*/{4, 5, 6}}, - // 1D Crop, with reshape. - TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, - /*begin=*/{0, 0}, /*end=*/{0, 3}, - /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, - /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, - /*expected_output=*/{3, 4, 5}}, - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, - /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, - /*expected_output=*/{3, 4, 5}}, - // Negative axis. - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, - /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, - /*expected_output=*/{1, 2, 3}}, - TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, - /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, - /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, - /*expected_output=*/{1, 2, 3, 4, 5}}, + // 2D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, 0, 0}, + /*end=*/{0, 0, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), + /*expected_output_dims=*/{1, 1, 2}, /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 2, 1}, + /*expected_output=*/{5, 6}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{ + /*input_dims=*/{2, 1, 3}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), /*expected_output_dims=*/{1, 1, 2}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1}), + /*expected_output_dims=*/{1, 2}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 0}), /*expected_output_dims=*/{1, 2, 2}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{ + /*input_dims=*/{1, 2, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 1, 3}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{ + /*input_dims=*/{2, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 1, 1}), /*expected_output_dims=*/{1, 3, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, /*strides=*/{1, 1}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 1, 0}), + /*expected_output_dims=*/{1, 3}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{3, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 1}), + /*expected_output_dims=*/{5, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + // Clamp out of bounds begin and end. + TestParams{/*input_dims=*/{1, 2, 3}, /*begin=*/{0, 0, -9999, -9}, + /*end=*/{0, 1, 1000, 4}, /*strides=*/{1, 1, 1, 1}, + /*begin_mask=*/get_mask({0, 0, 0, 0}), + /*end_mask=*/get_mask({1, 0, 0, 0}), + /*expected_output_dims=*/{1, 2, 3}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}}, +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) + // Strides + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 5}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 0}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{1, 3, 5}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 1}, /*end=*/{0, 6}, /*strides=*/{1, 2}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{3}, + /*expected_output=*/{2, 4, 6}}, + TestParams{/*input_dims=*/{6}, + /*begin=*/{0, 2}, /*end=*/{0, 6}, /*strides=*/{1, 3}, + /*begin_mask=*/get_mask({0, 0}), /*end_mask=*/get_mask({1, 0}), + /*expected_output_dims=*/{2}, + /*expected_output=*/{3, 6}}, +#endif }; for (int i = 0; i < kStridedSliceOKCases; i++) { @@ -2801,16 +2870,18 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { ok_params[i].begin); AddTestWeights("end", {static_cast(ok_params[i].end.size())}, ok_params[i].end); - std::vector strides(ok_params[i].input_dims.size(), 1); - AddTestWeights("strides", {static_cast(strides.size())}, - strides); + AddTestWeights("strides", + {static_cast(ok_params[i].strides.size())}, + ok_params[i].strides); RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); - const DataVec input_data{ - {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + const DataVec input_data{{"input", test::AsTensor(ok_input)}}; DataVec output_data{ {"my_strided_slice", ConstructTensor(ok_params[i].expected_output.size())}}; @@ -2820,6 +2891,148 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { } } +TEST_F(OpConverterTest, ConvertSlice) { + // Get nodedef for Slice layer. + auto get_slice_nodedef = []() -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto size = ops::Placeholder(s.WithOpName("size"), DT_INT32); + auto slice = ops::Slice(s.WithOpName("my_slice"), input, begin, size); + return slice.operation.node()->def(); + }; + + { + // Begin is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, -1, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Begin is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 3, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" for dimension 2 in Slice is out of range, at my_slice"); + } + { + // Size is below bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, -2}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 3 in Slice is out of range, at " + "my_slice"); + } + { + // Size is above bounds, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 3, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "\"begin\" + \"size\" for dimension 2 in Slice is out of range, at " + "my_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size with size[0] not -1, should fail. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {1, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "TensorRT does not allow modifications to the batch dimension, at " + "my_slice"); + } + { + // Dynamic batch size but using size[0] of -1, ok. + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}, /*batch_size=*/-1); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("size", {4}, {-1, 1, 2, 2}); + RunValidationAndConversion(node_def); + } + + struct TestParams { + std::vector input_dims; + std::vector begin; + std::vector size; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Ok. + const int kSliceOKCases = 5; + TestParams ok_params[kSliceOKCases] = { + TestParams{{1, 2, 3}, + {0, 0, 0, 0}, + {-1, -1, -1, -1}, + {1, 2, 3}, + {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, 1, 2, 3}, {1, 2, 3}, {1, 2, 3, 4, 5, 6}}, + TestParams{ + {1, 2, 3}, {0, 0, 0, 0}, {1, -1, 2, 2}, {1, 2, 2}, {1, 2, 4, 5}}, + TestParams{{6}, {0, 1}, {1, 5}, {5}, {2, 3, 4, 5, 6}}, + TestParams{{6}, {0, 1}, {-1, 3}, {3}, {2, 3, 4}}, + }; + + for (int i = 0; i < kSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_slice_nodedef(); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", + {static_cast(ok_params[i].begin.size())}, + ok_params[i].begin); + AddTestWeights("size", {static_cast(ok_params[i].size.size())}, + ok_params[i].size); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_slice", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor({1, 2, 3, 4, 5, 6})}}; + DataVec output_data{{"my_slice", ConstructTensor( + ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} + TEST_F(OpConverterTest, ConvertConv2D) { { // Input list is empty, should fail. @@ -2956,27 +3169,6 @@ TEST_F(OpConverterTest, ConvertConv2D) { } struct TestParams { - TestParams(const std::vector& input_dims, - const std::vector& input, - const std::vector& filter_dims, - const std::vector& filter, - const std::vector& strides, const string& padding, - const string& data_format, const std::vector& dilations, - bool is_conv2d_backprop_input, - const std::vector& expected_output_dims, - const std::vector& expected_output) - : input_dims(input_dims), - input(input), - filter_dims(filter_dims), - filter(filter), - strides(strides), - padding(padding), - data_format(data_format), - dilations(dilations), - is_conv2d_backprop_input(is_conv2d_backprop_input), - expected_output_dims(expected_output_dims), - expected_output(expected_output) {} - std::vector input_dims; std::vector input; std::vector filter_dims; @@ -3163,6 +3355,294 @@ TEST_F(OpConverterTest, ConvertTopK) { } } +template +void TestConvertGather(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), dtype); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + + struct TestParams { + std::vector params_dims; + std::vector indices_dims; + std::vector indices; + int axis; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Input is the same {1, 2, 3, 4, 5, 6} for all cases. + const int kGatherOKCases = 5; + TestParams ok_params[kGatherOKCases] = { + // Vector indices (output is rank(params)). + TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}}, + TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}}, + TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}}, + TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}}, + // Higher rank indices (output is rank(params) + rank(indices) - 1). + TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}}, + }; + + // Ok. + for (int i = 0; i < kGatherOKCases; i++) { + test->Reset(); + test->AddTestTensor("params", ok_params[i].params_dims, 1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("indices", ok_params[i].indices_dims, 1, + nvinfer1::DataType::kINT32); + test->AddTestWeights("axis", {1}, {ok_params[i].axis}); + test->RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_gather", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + // Create input in CType and convert expected output to CType. + std::vector inputs = {CType(1), CType(2), CType(3), + CType(4), CType(5), CType(6)}; + std::vector converted_expected_output( + ok_params[i].expected_output.begin(), + ok_params[i].expected_output.end()); + + const DataVec input_data{ + {"params", test::AsTensor(inputs)}, + {"indices", test::AsTensor(ok_params[i].indices)}}; + DataVec output_data{ + {"my_gather", + ConstructTensor(ok_params[i].expected_output.size())}}; + test->BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(converted_expected_output)); + } +} + +TEST_F(OpConverterTest, ConvertGather) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_gather", "GatherV2", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "GatherV2 got 0 inputs but expected 3, at my_gather"); + } + + // Get the NodeDef for GatherV2. + Scope s = Scope::NewRootScope(); + auto params = ops::Placeholder(s.WithOpName("params"), DT_FLOAT); + auto indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + auto gather = ops::GatherV2(s.WithOpName("my_gather"), params, indices, axis); + const NodeDef& node_def = gather.operation.node()->def(); + { + // Axis is a tensor, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestTensor("axis", {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for GatherV2 must be a constant, at my_gather"); + } + { + // Axis is out of bounds, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {4}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in " + "range [-4, 4), at my_gather"); + } + { + // Axis is batch dimension, should fail. + Reset(); + AddTestTensor("params", {1, 2, 3}); + AddTestTensor("indices", {2}); + AddTestWeights("axis", {1}, {0}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_gather"); + } + + Reset(); + TestConvertGather(this); + TestConvertGather(this); + TestConvertGather(this); +} + +TEST_F(OpConverterTest, ConvertUnary) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_unary", "Neg", {}); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Neg got 0 inputs but expected 1, at my_unary"); + } + { + // Input is weights, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto neg = ops::Neg(s.WithOpName("my_unary"), input); + const NodeDef& node_def = neg.operation.node()->def(); + AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"x\" for Neg must be a tensor, at my_unary"); + } + + // Get nodedef for unary layer. + auto get_unary_nodedef = [](string op_name) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + if (op_name == "Abs") { + auto unary = ops::Abs(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Acos") { + auto unary = ops::Acos(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Acosh") { + auto unary = ops::Acosh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Asin") { + auto unary = ops::Asin(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Asinh") { + auto unary = ops::Asinh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Atan") { + auto unary = ops::Atan(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Atanh") { + auto unary = ops::Atanh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Ceil") { + auto unary = ops::Ceil(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Cos") { + auto unary = ops::Cos(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Cosh") { + auto unary = ops::Cosh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Exp") { + auto unary = ops::Exp(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Floor") { + auto unary = ops::Floor(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Log") { + auto unary = ops::Log(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Neg") { + auto unary = ops::Neg(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Reciprocal") { + auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Rsqrt") { + auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sin") { + auto unary = ops::Sin(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sinh") { + auto unary = ops::Sinh(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Sqrt") { + auto unary = ops::Sqrt(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } else if (op_name == "Tan") { + auto unary = ops::Tan(s.WithOpName("my_unary"), input); + return unary.operation.node()->def(); + } + EXPECT_TRUE(false); + return NodeDef(); + }; + // Get expected output for unary layer. + auto get_unary_output = [](string op_name, float input) -> float { + if (op_name == "Abs") { + return std::abs(input); + } else if (op_name == "Acos") { + return std::acos(input); + } else if (op_name == "Acosh") { + return std::acosh(input); + } else if (op_name == "Asin") { + return std::asin(input); + } else if (op_name == "Asinh") { + return std::asinh(input); + } else if (op_name == "Atan") { + return std::atan(input); + } else if (op_name == "Atanh") { + return std::atanh(input); + } else if (op_name == "Ceil") { + return std::ceil(input); + } else if (op_name == "Cos") { + return std::cos(input); + } else if (op_name == "Cosh") { + return std::cosh(input); + } else if (op_name == "Exp") { + return std::exp(input); + } else if (op_name == "Floor") { + return std::floor(input); + } else if (op_name == "Log") { + return std::log(input); + } else if (op_name == "Neg") { + return -input; + } else if (op_name == "Reciprocal") { + return 1.0 / input; + } else if (op_name == "Rsqrt") { + return 1.0 / std::sqrt(input); + } else if (op_name == "Sin") { + return std::sin(input); + } else if (op_name == "Sinh") { + return std::sinh(input); + } else if (op_name == "Sqrt") { + return std::sqrt(input); + } else if (op_name == "Tan") { + return std::tan(input); + } + EXPECT_TRUE(false); + return 0; + }; + + // Get list of ops to test. + std::vector ops_to_test; + // Add all ops supported by ConvertUnary. + auto* map = UnaryOperationMap(); + ops_to_test.reserve(map->size()); + for (auto& pair : *map) { + ops_to_test.push_back(pair.first); + } + // Add other unary ops to test. + ops_to_test.push_back("Rsqrt"); + // Ok. + for (string op_name : ops_to_test) { + Reset(); + NodeDef node_def = get_unary_nodedef(op_name); + AddTestTensor("input", {1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + + const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + const DataVec input_data{{"input", test::AsTensor(input)}}; + DataVec output_data{{"my_unary", ConstructTensor(6)}}; + BuildAndRun(input_data, &output_data); + for (int i = 0; i < input.size(); ++i) { + const float expected_output = get_unary_output(op_name, input[i]); + EXPECT_THAT(GetSpanForData(output_data[0])[i], + NanSensitiveFloatNear(expected_output, 0.0001)); + } + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index f36aa558ea2ea463983caf163e17f83ae1c38f40..0eedfcacb4c11c8dc63fcfc13f044586b99b3c76 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -87,7 +87,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << offset << "type = " << cluster->type(); LOG(INFO) << offset << "num warmup steps = " << cluster->NumWarmupSteps(); const auto dev_names = cluster->GetDeviceNames(); - if (dev_names.size()) { + if (!dev_names.empty()) { LOG(INFO) << offset << " Device names:"; for (const auto s : dev_names) { LOG(INFO) << offset2 << s; @@ -103,7 +103,7 @@ void TRTOptimizationPass::PrintDebugInfo( } const auto dev_props = cluster->GetDevices(); - if (dev_props.size()) { + if (!dev_props.empty()) { LOG(INFO) << offset << "Device properties:"; for (auto k : dev_props) { LOG(INFO) << offset2 << k.first; @@ -131,7 +131,7 @@ void TRTOptimizationPass::PrintDebugInfo( } } LOG(INFO) << "item: " << item.id; - if (item.feed.size()) { + if (!item.feed.empty()) { LOG(INFO) << offset << "Feeds :"; for (const auto& f : item.feed) { const auto& shape = f.second.shape(); @@ -140,7 +140,7 @@ void TRTOptimizationPass::PrintDebugInfo( } else { LOG(INFO) << offset << "No Feeds"; } - if (item.fetch.size()) { + if (!item.fetch.empty()) { LOG(INFO) << offset << "Fetches :"; for (const auto& f : item.fetch) { LOG(INFO) << offset2 << f; @@ -149,7 +149,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << offset << "No Fetches"; } - if (item.init_ops.size()) { + if (!item.init_ops.empty()) { LOG(INFO) << offset << "init ops :"; for (const auto& f : item.init_ops) { LOG(INFO) << offset2 << f; @@ -160,7 +160,7 @@ void TRTOptimizationPass::PrintDebugInfo( LOG(INFO) << "Save Op = " << item.save_op; LOG(INFO) << "Restore Op = " << item.restore_op; LOG(INFO) << "save_restore_loc_tensor = " << item.save_restore_loc_tensor; - if (item.keep_ops.size()) { + if (!item.keep_ops.empty()) { LOG(INFO) << offset << "keep ops :"; for (const auto& f : item.keep_ops) { LOG(INFO) << offset2 << f; @@ -197,7 +197,7 @@ tensorflow::Status TRTOptimizationPass::Optimize( PrintDebugInfo(cluster, item); } int max_dim = -1; - if (item.feed.size()) { + if (!item.feed.empty()) { for (const auto& f : item.feed) { const auto& shape = f.second.shape(); if (shape.dims() > 0) { diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc index eae1f8e7525f1816d1c50072ebe4ba6713c96e47..81406b6e301ca350a3e52c97f5fcb575e88c3a90 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op.cc @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ -#define TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ - #include #include @@ -70,4 +67,3 @@ REGISTER_KERNEL_BUILDER(Name("GetSerializedResourceOp").Device(DEVICE_GPU), #endif // GOOGLE_TENSORRT #endif // GOOGLE_CUDA -#endif // TENSORFLOW_COMPILER_TF2TENSORRT_KERNELS_GET_SERIALIZED_RESOURCE_OP_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index bc5335ef5aa35633a68e69f7de7903b4f498531a..f6d387c59cd04aa5c7ccad610290b7b1f1d2b11f 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -216,8 +215,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("use_calibration", &use_calibration_)); calibration_mode_ = (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 && - calibration_data.size() == 0); - if (calibration_data.size()) { + calibration_data.empty()); + if (!calibration_data.empty()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); } @@ -254,6 +253,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, opts.rendezvous = ctx->rendezvous(); opts.cancellation_manager = ctx->cancellation_manager(); opts.runner = ctx->runner(); + inputs.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); i++) { inputs.push_back(ctx->input(i)); } @@ -294,27 +294,6 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, return this->AllocateCalibrationResources(ctx, cr); }})); tensorflow::core::ScopedUnref calib_sc(calib_res); - // TODO(aaroey): here we also add the resource to the ResourceMgr singleton. - // This is needed before we migrate all uses of calib_graph_to_infer_graph() - // to the new calibration workflow. After that we'll remove this block. - { - auto deprecated_rm = - TRTResourceManager::instance()->getManager("TRTCalibration"); - TRTCalibrationResource* copied_resource = nullptr; - // Check whether the resource exists, and create it if not. - if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource) - .ok()) { - // Do nothing if the resource exists. - copied_resource->Unref(); - } else { - copied_resource = calib_res; - // Increase the refcount by 1 then transfer the ownership of that refcount - // to the ResourceMgr singleton. - copied_resource->Ref(); - OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator", - copied_resource)); - } - } int num_inputs = ctx->num_inputs(); // Pass input data to calibrator std::unordered_map input_data; @@ -385,8 +364,9 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } // Get shapes of inputs to engine. std::vector input_shapes; + input_shapes.reserve(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { - input_shapes.emplace_back(ctx->input(i).shape()); + input_shapes.push_back(ctx->input(i).shape()); } EngineContext* engine_context = GetEngine(input_shapes, ctx); if (!engine_context->cuda_engine) { @@ -433,7 +413,8 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, auto dtype = cuda_engine->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: - buffers[binding_index] = (void*)(input_tensor.flat().data()); + buffers[binding_index] = + const_cast(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: LOG(ERROR) << "FP16 inputs are not supported yet!"; @@ -442,10 +423,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, LOG(ERROR) << "INT8 inputs are not supported yet!"; return kRetry; case nvinfer1::DataType::kINT32: - buffers[binding_index] = (void*)(input_tensor.flat().data()); + buffers[binding_index] = + const_cast(input_tensor.flat().data()); break; default: - LOG(ERROR) << "Unknown TRT data type: " << int(dtype); + LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); return kRetry; } } @@ -484,7 +466,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = - reinterpret_cast(output_tensor->flat().data()); + const_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: LOG(WARNING) << "half size is not supported yet!"; @@ -494,7 +476,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, return kRetry; case nvinfer1::DataType::kINT32: buffers[binding_index] = - reinterpret_cast(output_tensor->flat().data()); + const_cast(output_tensor->flat().data()); break; default: LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); @@ -616,11 +598,11 @@ EngineContext* TRTEngineOp::GetEngine( LOG(INFO) << "Building a new TensorRT engine for " << name() << " input shapes: " << TensorShapeUtils::ShapeListString(engine_input_shapes); + // Convert to partial shapes - std::vector partial_shapes; - for (int i = 0; i < engine_input_shapes.size(); i++) { - partial_shapes.emplace_back(engine_input_shapes[i]); - } + std::vector partial_shapes(engine_input_shapes.begin(), + engine_input_shapes.end()); + // Up to this point, calibrator_ can never be empty, since otherwise it // means calibration_mode_ is true and this path won't get executed. auto status = convert::ConvertGraphDefToEngine( diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py index 86bfabf99e08a8e447a28504c72eebca4d3a582c..25fb3a13db9911673bac04652b8ed8ba842be93c 100644 --- a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py +++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py @@ -18,17 +18,52 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading + import platform +from tensorflow.python.framework import errors + +_trt_ops_so = None +_module_lock = threading.Lock() + + +def load_trt_ops(): + """Load TF-TRT op libraries so if it hasn't been loaded already.""" + global _trt_ops_so + + if platform.system() == "Windows": + raise RuntimeError("Windows platforms are not supported") + + with _module_lock: + if _trt_ops_so: + return -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import * + try: + # pylint: disable=g-import-not-at-top,unused-variable + # This registers the TRT ops, it doesn't require loading TRT library. + from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op + # pylint: enable=g-import-not-at-top,unused-variable + except ImportError as e: + print("**** Failed to import TF-TRT ops. This is because the binary was " + "not built with CUDA or TensorRT enabled. ****") + raise e - from tensorflow.python.framework import load_library - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top + # TODO(laigd): we should load TF-TRT kernels here as well after removing the + # swig binding. + try: + # pylint: disable=g-import-not-at-top + from tensorflow.python.framework import load_library + from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top - _trt_ops = load_library.load_op_library( - resource_loader.get_path_to_datafile("_trt_ops.so")) -else: - raise RuntimeError("Windows platforms are not supported") + _trt_ops_so = load_library.load_op_library( + resource_loader.get_path_to_datafile("_trt_ops.so")) + except errors.NotFoundError as e: + no_trt_message = ( + "**** Failed to initialize TensorRT. This is either because the " + "TensorRT installation path is not in LD_LIBRARY_PATH, or because " + "you do not have it installed. If not installed, please go to " + "https://developer.nvidia.com/tensorrt to download and install " + "TensorRT ****") + print(no_trt_message) + raise e diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 4a8a4ac7589a4b68b129e8e88ee999e8a2495728..3794929b1df3fa999de6ab218dc2ddfb96e4ac81 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -30,6 +30,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { @@ -725,3 +728,6 @@ tensorflow::Status SegmentGraph( } // namespace segment } // namespace tensorrt } // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h index 9a0ccc9aef475edfb0ffb83a2be21d4d4ca0e028..9622ddd593990e93ba1b54e9dfd0052006e20ced 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -namespace tensorflow { +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +namespace tensorflow { namespace tensorrt { namespace segment { @@ -60,4 +62,7 @@ tensorflow::Status SegmentGraph( } // namespace tensorrt } // namespace tensorflow +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + #endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 58512d3b09d7c6f523710bc09843c628a5838b53..e11ad2719740d908f93ef580a6b308469365f402 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -26,6 +26,9 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { @@ -265,3 +268,6 @@ TEST_F(SegmentTest, BigIfElse) { } // namespace segment } // namespace tensorrt } // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc index 3bcca99afbff8b84d2dd628ae9211ee94e86af2a..dd3c09d7e42358a1f9e6cc13be6198de58e38963 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include "re2/re2.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/test_utils.h b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h index bcd628b62f0320f7ce9dfe6240316d876f1d5a20..d85875991b79014c4f173d3157ed02e6c96f045c 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/test_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/test_utils.h @@ -16,8 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_ -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index bf111d3a2ee2fbec9151d12bbb6ff7181761c2aa..5213fced1ea9220422245172f5b4a3f584a2a566 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -135,7 +135,7 @@ void TRTInt8Calibrator::setDone() { void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, std::size_t length) { - calibration_table_ = string((const char*)ptr, length); + calibration_table_ = string(static_cast(ptr), length); VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr << " length=" << length; } diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index 10587e99624acfb97730bbbd9dfbcde020ffc669..aa70b07f8d79848c362275815004db32cca128be 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -34,7 +34,12 @@ namespace tensorrt { // TRTs pull model for calibration. When TRT implements a means for // a push calibration This class should be updated accordingly +// IInt8EntropyCalibrator2 is prefferred for TRT 5.1+. +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { +#else struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { +#endif public: // Construct a calibrator for future calibration. TRTInt8Calibrator( diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index f454f55f2cb4ee65b97891ae8dd58d809d36f099..6bc842ed5ca7e03018157060a332338cdc926f14 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -26,7 +26,7 @@ namespace tensorrt { void Logger::log(Severity severity, const char* msg) { // Suppress info-level messages switch (severity) { -#if NV_TENSORRT_MAJOR >= 5 && NV_TENSORRT_MINOR >= 1 +#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1) case Severity::kVERBOSE: #endif case Severity::kINFO: { // Mark TRT info messages as debug! diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc deleted file mode 100644 index 0a72a88bc740101bcbadb40bfe106a5b8d284bbf..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h" -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace tensorrt { - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::instance() { - static std::shared_ptr instance_(new TRTResourceManager); - return instance_; -} - -std::shared_ptr -tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { - // mutex is held for lookup only. Most instantiations where mutex will be held - // longer will be during op creation and should be ok. - tensorflow::mutex_lock lock(map_mutex_); - auto s = managers_.find(op_name); - if (s == managers_.end()) { - auto it = managers_.emplace( - op_name, std::make_shared(op_name)); - VLOG(1) << "Returning a new manager " << op_name; - return it.first->second; - } - VLOG(1) << "Returning old manager " << op_name; - return s->second; -} - -} // namespace tensorrt -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h b/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h deleted file mode 100644 index 03879ffff2fa724b05cb1919753e4aaa99e2e702..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ -#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ -#include - -#include -#include -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { -namespace tensorrt { - -class TRTResourceManager { - TRTResourceManager() = default; - - public: - static std::shared_ptr instance(); - // returns a manager for given op, if it doesn't exists it creates one - std::shared_ptr getManager(const string& op_name); - - private: - std::unordered_map> - managers_; - tensorflow::mutex map_mutex_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc index 37f7fe99fbb2b9e121953fc0de211db1bbf34b7a..2e553079b19a3e5d0739cc6ac79a84f3b6a1fc4e 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_resources.cc @@ -48,7 +48,7 @@ Status TRTCalibrationResource::SerializeToString(string* serialized) { calibrator_->waitAndSetDone(); thr_->join(); *serialized = calibrator_->getCalibrationTableAsString(); - if (!serialized->size()) { + if (serialized->empty()) { return tensorflow::errors::Unknown("Calibration table is empty."); } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 02de95141da1a28e59d3155742217efdf163e8dd..7d9e7b9fc1f7ea83d6aa982afb5df097b0bdbf77 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -24,7 +24,7 @@ package( ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") cc_library( name = "tf2xla_supported_ops_lib", @@ -60,6 +60,14 @@ xla_proto_library( ], ) +xla_py_proto_library( + name = "tf2xla_py", + has_services = False, + api_version = 2, + visibility = ["//visibility:public"], + deps = [":tf2xla_proto"], +) + xla_proto_library( name = "host_compute_metadata_proto", srcs = ["host_compute_metadata.proto"], @@ -283,6 +291,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -448,6 +457,7 @@ cc_library( hdrs = [ "dump_graph.h", ], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index c8341a2c6bb66e43fb00cb660726cf5a1979c992..8aa162be47c9181e215de6a2eb660215135ff6eb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -63,6 +63,23 @@ size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const { return Hash64Combine(h, std::hash()(static_cast(ancestor.type))); } +typedef std::tuple + ClusterTuple; + +struct ClusterTupleLessThan { + bool operator()(const ClusterTuple& a, const ClusterTuple& b) const { + if (std::tie(std::get<0>(a), std::get<1>(a)) < + std::tie(std::get<0>(b), std::get<1>(b))) { + return true; + } else if (std::tie(std::get<0>(a), std::get<1>(a)) == + std::tie(std::get<0>(b), std::get<1>(b))) { + return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b)); + } else { + return false; + } + } +}; + // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { return absl::StrCat(tensor.node->name(), ":", tensor.index); @@ -744,9 +761,9 @@ Status Conditional::BuildIfNode(Graph* graph, } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), - predicate_.index, - predicate_.node->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index, + predicate_.node->output_type(predicate_.index))); // ... followed by the other inputs. builder.Input(inputs); @@ -1393,16 +1410,30 @@ Status FunctionalizeCond::FunctionalizeInternal() { // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Cluster merge nodes by CondId and AncestorId in order of nesting. - using ClusterPair = std::pair; + // Cluster merge nodes by (CondId, AncestorId, predicate) in order of + // nesting. (CondId, AncestorId) is not enough, e.g. + // pred1 = array_ops.placeholder(dtypes.bool, name='pred1') + // pred2 = array_ops.placeholder(dtypes.bool, name='pred2') + // cond1 = control_flow_ops.cond(pred1, ...) + // cond2 = control_flow_ops.cond(pred2, ...) + // cond3 = control_flow_ops.cond(pred1, use cond1 and cond2) + // cond4 = control_flow_ops.cond(pred2, use cond1 and cond2) + // cond3 and cond4 have the same (CondId, AncestorId), but they should not + // be merged into one "If" node (because they have different predicates). std::deque> merge_clusters; - std::map merge_cluster_index; + std::map merge_cluster_index; for (Node* merge : merge_order) { auto cond_id = state_map_.LookupCondId(merge); if (state_map_.IsDead(cond_id)) continue; - ClusterPair key = - std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto predicate = merge_to_predicate_.find(merge); + if (predicate == merge_to_predicate_.end()) { + return errors::Internal("Cannot find predicate for Merge node ", + merge->name()); + } + + ClusterTuple key = std::make_tuple( + cond_id, state_map_.LookupAncestorId(merge), predicate->second); auto idx = merge_cluster_index.find(key); if (idx == merge_cluster_index.end()) { merge_cluster_index[key] = merge_clusters.size(); @@ -1422,7 +1453,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { &state_map_); for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); TF_RETURN_IF_ERROR( - cond.BuildAndReplace(graph_, library_, &merge_to_predicate_)); + cond.BuildAndReplace(graph_, library_, &merge_to_replacement_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 69353fe87d833fba2c8766ed185481f2238a190d..343568b2392595a2347bde41f0a2e2559fb1de19 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -107,11 +107,13 @@ tf_kernel_library( "xla_pad_op.cc", "xla_reduce_op.cc", "xla_select_and_scatter_op.cc", + "xla_self_adjoint_eig_op.cc", ], hdrs = [ "index_ops.h", "shape_util.h", ], + tags = ["optonly"], deps = [ ":conv_op_helpers", ":if_op", @@ -143,8 +145,8 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/client/lib:quantize", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:bitwise_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", "//tensorflow/core:data_flow_ops_op_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index c2b4c28d1566f5429c5d8109db94af0c3762b131..a99c6ee4431852166eec0a71bb7ad74fd5c135d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -112,9 +113,12 @@ class CategoricalOp : public XlaOpKernel { xla::PrimitiveType type, XlaOpKernelContext* ctx) { xla::XlaBuilder* builder = ctx->builder(); - auto uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + // We want a number in (0, 1) rather than [0, 1) or (0, 1]: + // * log(-log(0)) is ∞. + // * log(-log(1)) is -∞. + auto uniforms = xla::RngUniform( + xla::MinPositiveNormalValue(builder, type), + xla::One(builder, uniform_shape.element_type()), uniform_shape); return xla::Log(-xla::Log(uniforms)); } @@ -143,9 +147,13 @@ class StatelessCategoricalOp : public CategoricalOp { if (uniform_shape.element_type() == xla::BF16) { uniform_shape.set_element_type(xla::F32); } + // We want a number in (0, 1) rather than [0, 1) or (0, 1]: + // * log(-log(0)) is ∞. + // * log(-log(1)) is -∞. auto uniforms = xla::StatelessRngUniform( - {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), - XlaHelpers::One(builder, DT_FLOAT)); + {seed0, seed1}, uniform_shape, + xla::MinPositiveNormalValue(builder, uniform_shape.element_type()), + xla::One(builder, uniform_shape.element_type())); return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 5f99b24e221ba6c926032ef7a1b4bf1e92df7a68..e8b270c67a23b876612ab1dba92a8ae7a46a392d 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -203,7 +203,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, absl::Span dilations, const std::vector& strides, - Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims, + absl::Span explicit_paddings) { TensorShape input_tensor_shape, filter_tensor_shape, out_backprop_tensor_shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); @@ -212,8 +213,8 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, - out_backprop_tensor_shape, dilations, strides, padding, - /*explicit_paddings=*/{}, data_format, dims); + out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings, + data_format, dims); } } // anonymous namespace @@ -227,10 +228,9 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); - // TODO(reedwm): Support explicit padding. if (attrs.padding == EXPLICIT) { - return errors::Unimplemented( - "XLA does not yet support Conv2D with explicit padding."); + TF_RETURN_IF_ERROR( + ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings)); } string data_format; @@ -303,6 +303,11 @@ xla::StatusOr MakeXlaForwardConvOp(StringPiece /*type_string*/, window_strides[i] = attrs.strides.at(dim); rhs_dilation[i] = attrs.dilations.at(dim); + if (attrs.padding == EXPLICIT) { + padding[i] = {attrs.explicit_paddings.at(dim * 2), + attrs.explicit_paddings.at(dim * 2 + 1)}; + } + int64 unused_output_size; TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( input_shape.dimensions(dim), filter_shape.dimensions(i), @@ -337,7 +342,7 @@ xla::StatusOr MakeXlaBackpropInputConvOp( TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, - attrs.data_format, &dims)); + attrs.data_format, &dims, attrs.explicit_paddings)); // The input gradients are computed by a convolution of the output // gradients and the filter, with some appropriate padding. See the @@ -420,7 +425,7 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( type_string, attrs.num_spatial_dims, activations_shape, expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, - attrs.padding, attrs.data_format, &dims)); + attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings)); // The activations (inputs) form the LHS of the convolution. // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] @@ -469,6 +474,8 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. @@ -495,6 +502,8 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // We apply negative padding in this case. const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + // + For the EXPLICIT padding, we pad the top/left side with the explicit + // padding and pad the bottom/right side with the remaining space. // + For the VALID padding, we don't pad anything on the top/left side // and pad the bottom/right side with the remaining space. // + For the SAME padding, we pad top/left side the same as bottom/right @@ -503,12 +512,12 @@ xla::StatusOr MakeXlaBackpropFilterConvOp( // In addition, if the padded input size is smaller than the input size, // we need to ignore some training elements of the input. We do this by // applying negative padding on the right/bottom. - const int64 pad_before = - attrs.padding == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - + const int64 pad_before = attrs.padding == Padding::EXPLICIT + ? attrs.explicit_paddings[2 * dim] + : attrs.padding == Padding::SAME + ? std::max(pad_total / 2, 0) + : 0; padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = attrs.dilations[dim]; } // Besides padding the input, we will also expand output_rows to diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 6e1b70a47850ae5c05939f8dfb7ec129c031df21..d893eca7f9ba07dded76eb215af4779080fa66b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -47,6 +47,7 @@ struct ConvOpAttrs { std::vector dilations; std::vector strides; Padding padding; + std::vector explicit_paddings; TensorFormat data_format; }; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index b96d45316f626e678a64392a4315979eeeb6e83c..d19d48e5dd95962fe4a4e4026eaf6b06b7898564 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -134,14 +135,15 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. -xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Unlike the bilinear kernel, which is triangular, the nearest neighbor @@ -153,11 +155,12 @@ xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, int64 n) { // to the right (because an existing non TPU kernel // for nearest neighbor resize already chose to default to the right, // so we want to be consistent). -xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { +xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, int64 n) { std::vector kernel(n * 2 - 1, 0.0f); std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f); - return xla::ConstantR1(builder, kernel); + return xla::ConvertElementType(xla::ConstantR1(builder, kernel), type); } // Kernels with more than 16 spatial elements are considered intense and the @@ -165,42 +168,66 @@ xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, + xla::PrimitiveType type, absl::Span kernel_size, int64 channels, bool is_kernel_bilinear) { auto make_kernel_func = is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; - auto depthwise_kernel = xla::Broadcast( - xla::Zero(builder, xla::F32), - {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); + std::vector depthwise_kernel_sizes = { + (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}; + auto depthwise_kernel = + xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]), + depthwise_kernel_sizes, /*broadcast_dimensions=*/{1}); - return xla::Mul( - xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[1]), - /*broadcast_dimensions=*/{1}), - make_kernel_func(builder, kernel_size[0]), - /*broadcast_dimensions=*/{0}); + return xla::Mul(depthwise_kernel, + make_kernel_func(builder, type, kernel_size[0]), + /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder, + xla::PrimitiveType type, absl::Span kernel_size, int64 channels, int64 dim, bool is_kernel_bilinear) { auto make_kernel_func = is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel; - auto depthwise_kernel = - xla::Broadcast(xla::Zero(builder, xla::F32), - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}); - return xla::Add(depthwise_kernel, make_kernel_func(builder, kernel_size[dim]), - /*broadcast_dimensions=*/{dim}); + std::vector depthwise_kernel_sizes = { + dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1}; + return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]), + depthwise_kernel_sizes, + /*broadcast_dimensions=*/{dim}); +} + +xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder, + const xla::XlaOp& input, + int32 spatial_dimensions_offset, + absl::Span in_size, + absl::Span out_size) { + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + auto broadcast_shape_or_status = builder->GetShape(input); + if (!broadcast_shape_or_status.ok()) { + return builder->ReportError(broadcast_shape_or_status.status()); + } + xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie(); + for (int32 i = 0; i < in_size.size(); ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + broadcast_shape.set_dimensions(spatial_dimensions_offset + i, + out_size[i]); + } + } + return xla::BroadcastInDim(input, broadcast_shape.dimensions(), + /*broadcast_dimensions=*/{0, 1, 2, 3}); } xla::XlaOp ResizeUsingDilationAndConvolution( - xla::XlaBuilder* builder, const xla::XlaOp& input, - const int num_spatial_dims, std::vector in_size, - std::vector out_size, const int64 channels, const bool align_corners, - bool is_kernel_bilinear) { + xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span out_size, const int64 channels, + const bool align_corners, bool is_kernel_bilinear) { // Picture for a 1x3 to 1x4 bilinear resize: // stride = 2, kernel size = 3 // Input: @@ -287,7 +314,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( // Split convolutions into independent dimensions if they would be a very // large kernel. if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, channels, is_kernel_bilinear); output = xla::ConvGeneralDilated(input_data, kernel, dims.stride, @@ -299,7 +326,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); output = xla::ConvGeneralDilated( input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ @@ -308,7 +335,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution( /*rhs_dilation=*/{1, 1}, dimension_numbers, /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ @@ -320,19 +347,14 @@ xla::XlaOp ResizeUsingDilationAndConvolution( // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && out_size[i] > 1) { - output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); - } - } - return output; + return BroadcastSpatialDimensions( + builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size); } xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( - xla::XlaBuilder* builder, const xla::XlaOp& grad, - const int num_spatial_dims, std::vector in_size, - std::vector grad_size, const int64 channels, + xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type, + const int num_spatial_dims, absl::Span in_size, + absl::Span grad_size, const int64 channels, const bool align_corners, bool is_kernel_bilinear) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); @@ -353,19 +375,14 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { - xla::XlaOp kernel = MakeGeneralResizeKernel(builder, dims.kernel_size, + xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size, channels, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a size == 1 // dimension to a size > 1 dimension. This has the effect of summing the // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = - xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), - /*broadcast_dimensions=*/{i}); - } - } + kernel = BroadcastSpatialDimensions( + builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size); output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, @@ -377,22 +394,22 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 0, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear); xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim( - builder, dims.kernel_size, channels, 1, is_kernel_bilinear); + builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear); // Broadcast the input kernel where the forward op expanded from a // size == 1 dimension to a size > 1 dimension. This has the effect of // summing the gradient contributions in that dimension. if (in_size[0] == 1 && grad_size[0] > 1) { - kernel0 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + kernel0 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, {1}, + {grad_size[0]}); } if (in_size[1] == 1 && grad_size[1] > 1) { - kernel1 = - xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + kernel1 = BroadcastSpatialDimensions(builder, kernel0, + /*spatial_dimensions_offset=*/0, + in_size, grad_size); } output = xla::ConvGeneralDilated( @@ -423,7 +440,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp( } } if (pad_output) { - output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); + output = xla::Pad(output, xla::Zero(builder, type), padding); } return output; } @@ -458,6 +475,7 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, const int num_spatial_dims = 2; xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. @@ -475,8 +493,11 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } - // Output is always type float. - input = xla::ConvertElementType(input, xla::F32); + // Output is always type float if 'is_kernel_bilinear' is true. + if (is_kernel_bilinear) { + input = xla::ConvertElementType(input, xla::F32); + input_type = xla::F32; + } // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -504,19 +525,19 @@ void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_, std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels, - align_corners_, is_kernel_bilinear); + b, input, input_type, num_spatial_dims, in_size, next_out_size, + channels, align_corners_, is_kernel_bilinear); input = output; in_size = next_out_size; } else { output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels, + b, input, input_type, num_spatial_dims, in_size, out_size, channels, align_corners_, is_kernel_bilinear); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels, + b, input, input_type, num_spatial_dims, in_size, out_size, channels, align_corners_, is_kernel_bilinear); in_size = out_size; } @@ -631,19 +652,19 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels, - align_corners_, true); + b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size, + channels, align_corners_, true); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, align_corners_, true); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels, + b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels, align_corners_, true); in_size = grad_size; } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 47cf8c6675bc120653c2a5ab6d4b07376dc382ee..39d96e748b3a2a852c03c0dd53ec175f0c66a43a 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -25,9 +25,6 @@ limitations under the License. namespace tensorflow { EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*)); - float* input = static_cast(data[0]); int64 input_size = *static_cast(data[1]); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 90c0ebefb24ec2c4378782e9b15d3f57c33032a4..5a6569c8954d1686dc9d7577a66feb720241ea13 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { namespace { @@ -31,7 +32,10 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, - /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); + /*lower=*/lower_, /*unit_diagonal=*/false, + /*transpose_a=*/ + adjoint_ ? xla::TriangularSolveOptions::ADJOINT + : xla::TriangularSolveOptions::NO_TRANSPOSE); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 01b047f732f0e9fb3b45b272e7886e2f8cf4fff4..d6c70d4af1c2e921b70b0869f0163c8481017c7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -279,9 +280,9 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp one = xla::One(b, xla_shape.element_type()); xla::XlaOp min_positive = - XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + xla::MinPositiveNormalValue(b, xla_shape.element_type()); auto uniform = xla::RngUniform(min_positive, one, xla_shape); ctx->SetOutput(0, TruncatedNormal(uniform)); } diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index a95e7adacf194ba6eb33cbeb56abe1a5a2479337..a1c18bed3f94008af8038f32324c79aa5b2abded 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -110,10 +110,16 @@ class ScatterNdOp : public XlaOpKernel { auto updates = context->Input(1); auto result = XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + /*indices_are_vectors=*/true, /*combiner=*/Combine, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.ValueOrDie()); } + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } }; REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"), diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 31d4cc131600f360c764ffa02831046c85d846e5..280b68383c28d1b9d88f7b2ac0f8fab47244c05d 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -104,7 +104,7 @@ class SizeOp : public XlaOpKernel { for (int64 i = 0; i < rank; ++i) { size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i)); } - size = xla::ConvertElementType(size, xla::S32); + size = xla::ConvertElementType(size, ctx->output_xla_type(0)); ctx->SetOutput(0, size); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 50653d7b3973b73d580cdeec5d71943b575d7cc9..17f067e0dfcf4f8b360ee6db934df3e373d5fdd1 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -218,8 +218,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); auto uniform = xla::StatelessRngUniform( {seed0, seed1}, xla_shape, - xla::ConstantR0(builder, std::numeric_limits::min()), - xla::ConstantR0(builder, 1.0)); + xla::MinPositiveNormalValue(builder, xla_shape.element_type()), + xla::One(builder, xla_shape.element_type())); auto output = TruncatedNormal(uniform); output = MaybeConvertF32ToBF16(output, dtype_); ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 65020012283d9c5f62e5e2fd11fc2bf1110e019a..8958a48bc79dce91c41ab7d0a5fc0fbb401112ba 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -69,6 +71,43 @@ class TensorListLengthOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp); +// Creates an empty list with size (leading_dim, *element_shape) if +// element_shape is known at compile time. Otherwise creates one with size +// (leading_dim, 0) which gets initialized later in `GetInitializedList`. +Status CreateZerosList(XlaOpKernelContext* ctx, int element_shape_index, + int64 leading_dim, DataType dtype, xla::XlaOp* list) { + TensorShape list_shape; + list_shape.AddDim(leading_dim); + xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); + TF_ASSIGN_OR_RETURN( + bool is_element_shape_compile_time_const, + element_shape_handle.builder()->IsConstant(element_shape_handle)); + PartialTensorShape partial_element_shape; + if (is_element_shape_compile_time_const) { + TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape( + element_shape_index, &partial_element_shape)); + } + if (is_element_shape_compile_time_const && + partial_element_shape.IsFullyDefined()) { + TensorShape element_shape; + partial_element_shape.AsTensorShape(&element_shape); + list_shape.AppendShape(element_shape); + } else { + // If element_shape is not a compile time constant or if it is not fully + // defined we will have to wait for the first write call to fully allocate + // the array. + // TODO(srbs): We are using element_shape of [0] as a proxy to denote an + // uninitialized list. A better implementation may be to represent the + // list as a 3-tuple containining an explicit "initialized" flag. However, + // we would still need to create a dummy tensor for the first tuple + // element. + list_shape.AddDim(0); + } + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + return Status::OK(); +} + class TensorListReserveOp : public XlaOpKernel { public: explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -76,20 +115,15 @@ class TensorListReserveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); - TensorShape tensor_shape; - tensor_shape.AddDim(num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, CreateZerosList(ctx, 0, num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, num_elements)})); + 0, xla::Tuple(b, {list, xla::ConstantR0(b, num_elements)})); } private: @@ -110,8 +144,6 @@ class EmptyTensorListOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - TensorShape element_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); OP_REQUIRES( @@ -119,15 +151,13 @@ class EmptyTensorListOp : public XlaOpKernel { errors::InvalidArgument("XLA compilation requires a fixed tensor list " "size. Set the max number of elements.")); - TensorShape tensor_shape; - tensor_shape.AddDim(max_num_elements); - tensor_shape.AppendShape(element_shape); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, + CreateZerosList(ctx, 0, max_num_elements, dtype_, &list)); xla::XlaBuilder* b = ctx->builder(); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), - tensor_shape.dim_sizes()), - xla::ConstantR0(b, 0)})); + 0, xla::Tuple(b, {list, xla::ConstantR0(b, 0)})); } private: @@ -274,6 +304,36 @@ REGISTER_XLA_OP( Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"), TensorListFromTensorOp); +// Returns the 0'th element of `tuple` containing the list tensor if it has been +// initialized already else creates one lazily. This allows lazy initialization +// of the list on the first call to SetItem or PushBack. +Status GetInitializedList(XlaOpKernelContext* ctx, const xla::XlaOp& tuple, + const TensorShape& element_shape, DataType dtype, + xla::XlaOp* list) { + *list = xla::GetTupleElement(tuple, 0); + TensorShape list_shape; + TF_RETURN_IF_ERROR(GetTensorListShape(ctx->builder(), tuple, &list_shape)); + int64 leading_dim = list_shape.dim_size(0); + TensorShape list_element_shape = list_shape; + list_element_shape.RemoveDim(0); + // This checks for the lazy initialization contract set by CreateEmptyList. + // In TensorListReserve if the element_shape is not known at compile time, + // it creates a list with shape [leading_dim, 0]. + if (element_shape != list_element_shape) { + if (list_element_shape.num_elements() != 0) { + return errors::InvalidArgument( + "Invalid shape of value in TensorListSetItem. Expected: ", + list_element_shape.DebugString(), + " Actual: ", element_shape.DebugString()); + } + list_shape = element_shape; + list_shape.InsertDim(0, leading_dim); + *list = xla::Broadcast(XlaHelpers::Zero(ctx->builder(), dtype), + list_shape.dim_sizes()); + } + return Status::OK(); +} + class TensorListSetItemOp : public XlaOpKernel { public: explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -285,7 +345,9 @@ class TensorListSetItemOp : public XlaOpKernel { xla::XlaOp tl = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(2); - xla::XlaOp ta = xla::GetTupleElement(tl, 0); + xla::XlaOp list; + OP_REQUIRES_OK(ctx, GetInitializedList(ctx, tl, elem_shape, dtype_, &list)); + xla::XlaOp index = ctx->Input(1); xla::XlaOp value = ctx->Input(2); @@ -299,8 +361,8 @@ class TensorListSetItemOp : public XlaOpKernel { auto update = xla::Reshape(value, slice_shape.dim_sizes()); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), - index + xla::ConstantR0(b, 1)})); + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), + xla::GetTupleElement(tl, 1)})); } private: @@ -319,11 +381,14 @@ class TensorListPushBackOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp tl = ctx->Input(0); + xla::XlaOp list_tuple = ctx->Input(0); TensorShape elem_shape = ctx->InputShape(1); - xla::XlaOp ta = xla::GetTupleElement(tl, 0); - xla::XlaOp index = xla::GetTupleElement(tl, 1); + xla::XlaOp list; + OP_REQUIRES_OK( + ctx, GetInitializedList(ctx, list_tuple, elem_shape, dtype_, &list)); + + xla::XlaOp index = xla::GetTupleElement(list_tuple, 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -336,7 +401,7 @@ class TensorListPushBackOp : public XlaOpKernel { auto update = xla::Reshape(value, slice_shape.dim_sizes()); ctx->SetTensorListOutput( - 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(list, update, start_indices), index + xla::ConstantR0(b, 1)})); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 26d4214099d1d07c1b2e275d783654d9cd948e28..ceb762038009f7a3ff80d9ad4066af43d54a9e34 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -856,15 +856,12 @@ class ResourceApplyAdadelta : public XlaOpKernel { xla::XlaOp grad = ctx->Input(6); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - accum = rho * accum + (one - rho) * xla::Pow(grad, two); - xla::XlaOp update = xla::Pow(accum_update + epsilon, half) * - xla::Pow(accum + epsilon, neg_half) * grad; - accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two); + accum = rho * accum + (one - rho) * xla::Square(grad); + xla::XlaOp update = + xla::Sqrt(accum_update + epsilon) * xla::Rsqrt(accum + epsilon) * grad; + accum_update = rho * accum_update + (one - rho) * xla::Square(update); var = var - update * lr; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 76793d677ba45f8e863e684a149da684c8ce8787..4ac714306248302242902f20d45d2609ef2c7cd3 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -19,6 +19,7 @@ limitations under the License. // helper. #include "tensorflow/core/kernels/transpose_op.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -128,29 +129,46 @@ class InvertPermutationOp : public XlaOpKernel { errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); - std::vector perm; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); - - int size = perm.size(); + auto e = ctx->InputExpression(0); + auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client()); + OP_REQUIRES_OK(ctx, tensor_or_status.status()); + // If the input is a constant, we also want the output to be a constant. + // Some models rely on the result of InvertPermutation being a constant. + // TODO(b/32495713): Remove this when we can check whether Scatter is + // constant. Right now, we always assume it is non-constant because we don't + // check the embedded computation. + if (tensor_or_status.ValueOrDie().has_value()) { + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); + + int size = perm.size(); + + std::vector output(size); + std::fill_n(output.data(), size, -1); + for (int i = 0; i < size; ++i) { + const int64 d = perm[i]; + OP_REQUIRES(ctx, FastBoundsCheck(d, size), + errors::InvalidArgument(d, " is not between 0 and ", size)); + OP_REQUIRES(ctx, output[d] == -1, + errors::InvalidArgument(d, " is duplicated in the input.")); + output[d] = i; + } - std::vector output(size); - std::fill_n(output.data(), size, -1); - for (int i = 0; i < size; ++i) { - const int64 d = perm[i]; - OP_REQUIRES(ctx, FastBoundsCheck(d, size), - errors::InvalidArgument(d, " is not between 0 and ", size)); - OP_REQUIRES(ctx, output[d] == -1, - errors::InvalidArgument(d, " is duplicated in the input.")); - output[d] = i; + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); + } else { + auto indices = ctx->Input(0); + int size = ctx->InputShape(0).num_elements(); + auto iota = xla::Iota(ctx->builder(), xla::S32, size); + auto result = XlaScatter(iota, iota, indices, + /*indices_are_vectors=*/false, /*combiner=*/{}, + ctx->builder()); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } - - ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } }; -REGISTER_XLA_OP(Name("InvertPermutation") - .TypeConstraint("T", DT_INT32) - .CompileTimeConstantInput("x"), +REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 4544e03491438d5f21cf986bc952572bd19d548c..62b5cd32da59063f8ce07119fd085f91ec3a1bc4 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -89,8 +89,9 @@ xla::XlaOp Sigmoid(xla::XlaOp x) { } XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); -// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); +// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, + xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x))); XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..233ac8e7b455403f8ee65b95b1403ecefdb92c6b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/core/lib/core/bits.h" + +namespace tensorflow { +namespace { + +class XlaSelfAdjointEigOp : public XlaOpKernel { + public: + explicit XlaSelfAdjointEigOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = + xla::SelfAdjointEig(ctx->Input(0), lower_, max_iter_, epsilon_); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } + + private: + bool lower_; + int32 max_iter_; + float epsilon_; +}; + +class SelfAdjointEigV2Op : public XlaOpKernel { + public: + explicit SelfAdjointEigV2Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape("input"); + int n = input_shape.dim_size(input_shape.dims() - 1); + // This is based on heuristics that approx log(n) sweep updates are needed. + // Note: the heuristics provides no theoretical guarantee, max_iter=100 and + // epsilon should be used to determine exit condition. + int max_iter = 2 * tensorflow::Log2Ceiling(n); + auto result = xla::SelfAdjointEig(ctx->Input(0), true, max_iter, 1e-6); + ctx->SetOutput(0, result.w); + ctx->SetOutput(1, result.v); + } +}; + +REGISTER_XLA_OP(Name("XlaSelfAdjointEig").TypeConstraint("T", kFloatTypes), + XlaSelfAdjointEigOp); +REGISTER_XLA_OP(Name("SelfAdjointEigV2").TypeConstraint("T", kFloatTypes), + SelfAdjointEigV2Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index af641131ed76a8d6a7291c360302fa17c94af014..ccd58071d350e605e0e1f0c2b43643a400e32c2c 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -56,6 +56,41 @@ lhs_output: the broadcasted LHS tensor rhs_output: the broadcasted RHS tensor )doc"); +REGISTER_OP("XlaSelfAdjointEig") + .Input("a: T") + .Attr("lower: bool") + .Attr("max_iter: int") + .Attr("epsilon: float") + .Output("w: T") + .Output("v: T") + .SetShapeFn(shape_inference::UnknownShape) + .Attr("T: numbertype") + .Doc(R"doc( +Computes the eigen decomposition of a batch of self-adjoint matrices +(Note: Only real inputs are supported). + +Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +i=0...N-1. + +a: the input tensor. + +lower: a boolean specifies whether the calculation is done with the lower + triangular part or the upper triangular part. + +max_iter: maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximatly logN sweeps are needed in practice (Ref: Golub & + van Loan "Matrix Computation"). + +epsilon: the tolerance ratio. + +w: The eigenvalues in ascending order, each repeated according to its + multiplicity. +v: The column v[..., :, i] is the normalized eigenvector corresponding to the + eigenvalue w[..., i]. +)doc"); + REGISTER_OP("XlaConv") .Input("lhs: T") .Input("rhs: T") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 345193c936a885e5a9e468979c4b73b5b0c9e5c2..de4710d03a3e69afb04aa68e37961698f0e3a300 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,6 +291,10 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) +def self_adjoint_eig(a, lower, max_iter, epsilon): + return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon) + + dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index cf48576ec2746fb29779633275eac4c638b91e45..28a4566c9d284fb8410a2d618f368c4dd2c1d893 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -254,7 +254,8 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, +Status ConvertGraphToXla(std::unique_ptr graph, + const tf2xla::Config& config, xla::Client* client, xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { @@ -264,6 +265,19 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); + // Populate arguments with resource variables from the config. The variables + // get turned into inputs and outputs. + for (const tf2xla::Variable& variable : config.variable()) { + XlaCompiler::Argument arg; + arg.type = variable.type(); + arg.kind = XlaCompiler::Argument::kResource; + arg.shape = variable.shape(); + arg.name = variable.node_name(); + arg.resource_kind = XlaResource::kVariable; + arg.initialized = true; + xla_args.push_back(std::move(arg)); + } + // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; @@ -361,7 +375,8 @@ Status ConvertGraphDefToXla(const GraphDef& graph_def, xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); + TF_RETURN_IF_ERROR( + ConvertGraphToXla(std::move(graph), config, client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.proto b/tensorflow/compiler/tf2xla/tf2xla.proto index 18c9089f5fa0e9792a4763d9bfac4c4e826eb5b2..5627af7452b99da594c1c214d0b556d8d70544d5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.proto +++ b/tensorflow/compiler/tf2xla/tf2xla.proto @@ -39,6 +39,15 @@ message Fetch { string name = 2; // Optional name for generated code. }; +// Variable represents a resource variable with the given name, shape and type. +message Variable { + string node_name = 1; + string name = + 2; // Optional name for generated code. If empty, node_name will be used. + TensorShapeProto shape = 3; + DataType type = 4; +} + // Config represents configuration information for tf2xla conversion. message Config { // Each feed is a positional input argument for the generated computation. @@ -47,4 +56,6 @@ message Config { // Each fetch is a positional output argument for the generated computation. // The order of each entry matches the order of each output argument. repeated Fetch fetch = 2; + // Each variable is a named input and output of the generated computation. + repeated Variable variable = 3; }; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index c64f78e1a1bcdd40b1c885889ec5fa491cfa1f66..88c03a6056ac6484013c3fd32c9889899b5c15c5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -122,7 +122,12 @@ Status ReplaceArgUsageWithConstNode( for (const auto& iter : const_input_index_to_node) { int arg_index = iter.first; - Node* const_node = g->CopyNode(iter.second); + NodeDef const_def = iter.second->def(); + const_def.set_name(g->NewName(const_def.name())); + Status s; + Node* const_node = g->AddNode(const_def, &s); + TF_RETURN_IF_ERROR(s); + Node* arg_node = arg_nodes[arg_index]; // Collect all usages of the _Arg node. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 9e9c3cecee68aee856141a620f7292f771978acb..28b4744470e7d28863b5f7275f829b9bd59641e1 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -363,5 +363,58 @@ TEST(PropagateConstIntoFunctionalNodes, WhileLoopWithResourceInput) { TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); } +TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) { + FunctionLibraryDefinition fld(OpRegistry::Global(), {}); + { + // Cond graph & body graph. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto input = ops::_Arg(scope.WithOpName("arg1"), DT_BOOL, 1); + auto duplicate_name = ops::NoOp(scope.WithOpName("duplicate_name")); + auto ret = ops::_Retval(scope.WithOpName("ret"), pred, 0); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "cond", &cond_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(cond_fdef)); + FunctionDef body_fdef; + TF_ASSERT_OK(GraphToFunctionDef(graph, "body", &body_fdef)); + TF_ASSERT_OK(fld.AddFunctionDef(body_fdef)); + } + Scope scope = Scope::NewRootScope().ExitOnError(); + auto pred = + ops::Const(scope.WithOpName("duplicate_name"), false, TensorShape({})); + auto input = ops::Const(scope.WithOpName("input"), false, TensorShape({})); + NameAttrList cond_fn, body_fn; + cond_fn.set_name("cond"); + body_fn.set_name("body"); + auto while_op = + ops::While(scope.WithOpName("while"), + std::initializer_list{pred, input}, cond_fn, body_fn); + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + TF_EXPECT_OK(PropagateConstIntoFunctionalNodes(&graph, &fld, &fld)); + + // Check that in rewritten body function, the NoOp node still has name + // "duplicate_name", and the copied Const node has name "duplicate_name/_0". + auto node_name_index = graph.BuildNodeNameIndex(); + Node* while_node = node_name_index["while"]; + ASSERT_NE(while_node, nullptr); + TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn)); + const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name()); + ASSERT_NE(rewritten_body_fn, nullptr); + std::unordered_map nodes; + for (const NodeDef& node_def : rewritten_body_fn->node_def()) { + nodes[node_def.name()] = node_def; + } + auto noop_def = nodes.find("duplicate_name"); + ASSERT_NE(noop_def, nodes.end()); + EXPECT_EQ(noop_def->second.op(), "NoOp"); + auto const_def = nodes.find("duplicate_name/_0"); + ASSERT_NE(const_def, nodes.end()); + EXPECT_EQ(const_def->second.op(), "Const"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index ddb284966eeb97cc7c9d3ed77fb313e567975e59..5bd0277c051711f2677b90a2679662899521e94a 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -60,8 +60,6 @@ class XlaCompilationAllocator : public Allocator { // buffers, so they get ids to track. bool ShouldAllocateEmptyTensors() override { return true; } - void GetStats(AllocatorStats* stats) override { stats->Clear(); } - private: // Don't run any constructors or destructors for complex objects, // since there is no backing store for the tensor to run them diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 514b156deb9f350813237c31b7657a5b09c800dd..3221ec5b727de1f792cd61b792ee917588d56cf9 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -43,6 +43,8 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -183,9 +185,10 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of which retvals have layout to update. The first element is - // the output index, second element is the new layout. - std::vector> retval_to_update_layout; + // Keeps track of the layout of each retval. If a retval is not in this list, + // a descending layout is used. The first element is the output index, second + // element is the new layout. + std::vector> retval_index_and_layout; for (int i = 0; i < retvals.size(); ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; @@ -214,7 +217,7 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( output.shape, output.type)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + retval_index_and_layout.emplace_back(elems.size(), shape.layout()); } else if (it != retval_cores.end()) { // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); @@ -287,6 +290,11 @@ Status BuildComputation( // Ensures the correct sharding is applied to the output. handle = identity_op(handle); + // Set layout of the retval to device representation layout. + if (resource->representation_shape().has_value()) { + retval_index_and_layout.emplace_back( + elems.size(), resource->representation_shape()->layout()); + } elems.push_back(handle); } } @@ -316,15 +324,15 @@ Status BuildComputation( computation->GetProgramShape()); *output_shape = program_shape.result(); // Update the output layout to the layout of retval. - for (auto& update : retval_to_update_layout) { + for (auto& index_and_layout : retval_index_and_layout) { if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = update.second; + *output_shape->mutable_layout() = index_and_layout.second; continue; } - xla::Shape* output_sub_shape = - xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); - *output_sub_shape->mutable_layout() = update.second; + xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( + output_shape, {index_and_layout.first}); + *output_sub_shape->mutable_layout() = index_and_layout.second; } return Status::OK(); } @@ -1108,8 +1116,17 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->outputs.resize(context->retvals().size()); std::vector retvals = context->retvals(); if (options.resolve_compile_time_constants) { - TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( - client(), absl::Span(retvals))); + Status status = ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals)); + + // If the HloEvaluator has not implemented an expression, just evaluate it + // at runtime. + if (status.code() == error::UNIMPLEMENTED) { + ConvertConstantsToExpressions(&builder, + absl::Span(retvals)); + } else { + TF_RETURN_IF_ERROR(status); + } } else { ConvertConstantsToExpressions(&builder, absl::Span(retvals)); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 492010f7317d32a8a620147cd2cd9356d4f13fde..b31137867d738944eaaa73e142ad8538ec6b854a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -277,6 +277,97 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } +// Tests that the compiler can correctly propagate the layout assigned by +// shape_representation_fn_ to return types. +TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + options.shape_representation_fn = + [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + +// The layout of resource variable shouldn't change after transpose +TEST_F(XlaCompilerTest, TransposeVariables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto transposed_read = ops::Transpose(scope, read, {1, 0}); + auto reshape = ops::Reshape(scope, transposed_read, {2, 3}); + auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 3}); + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose", + std::move(graph), args, &result)); + xla::Shape transposed = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0}); + // Check that the return shapes are correctly tranposed. + EXPECT_EQ(result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); +} + // Tests that the compiler doesn't reorder the parameters. TEST_F(XlaCompilerTest, MixedOrderArguments) { for (bool swap_order : {false, true}) { diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 6139bf3cea0790c2697130a993e92be96c81848b..3f787fd86c9f7366a7728dcf146a3797ba672bc3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -76,7 +76,7 @@ XlaResource* XlaContext::AddResource(std::unique_ptr resource) { } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { - return LookupOrCreate(type, &max_func_, [this, type] { + return LookupOrCreate(type, &max_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; xla::XlaBuilder b("max<" + type_string + ">"); @@ -92,7 +92,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { - return LookupOrCreate(type, &min_func_, [this, type] { + return LookupOrCreate(type, &min_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; xla::XlaBuilder b("min<" + type_string + ">"); @@ -108,7 +108,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { - return LookupOrCreate(type, &add_func_, [this, type] { + return LookupOrCreate(type, &add_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; xla::XlaBuilder b("add<" + type_string + ">"); @@ -124,7 +124,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { } const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { - return LookupOrCreate(type, &mul_func_, [this, type] { + return LookupOrCreate(type, &mul_func_, [type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; xla::XlaBuilder b("mul<" + type_string + ">"); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 04a5d934064a9083a41cc210b48df65bbc862fff..7bb1ad27467a5b281626de4203169e575288f9ee 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -81,61 +81,27 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return Status::OK(); } -template -static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { - Tensor linspace(DataTypeToEnum::v(), shape); - auto linspace_flat = linspace.flat(); - for (int64 i = 0; i < depth; ++i) { - linspace_flat(i) = i; - } - return linspace; -} - Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::XlaOp& indices, const xla::XlaOp& on_value, const xla::XlaOp& off_value, xla::XlaOp* one_hot) { - const int indices_dims = indices_shape.dims(); - const int output_dims = indices_dims + 1; - - TensorShape output_shape = indices_shape; - output_shape.InsertDim(axis, depth); - - // Build a Tensor populated with values 0, 1, 2, ... depth. - std::vector linspace_dims(output_dims, 1); - linspace_dims[axis] = depth; - TensorShape linspace_shape(linspace_dims); - Tensor linspace; - switch (index_type) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, depth); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(index_type)); - } - - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = xla::Eq( - indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); + + TensorShape output_shape = indices_shape; + output_shape.InsertDim(axis, depth); + xla::Shape iota_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(index_type, output_shape, &iota_shape)); // Selects the user-provided off_value and on_value values. - *one_hot = xla::Select(one_hot_bool, - xla::Broadcast(on_value, output_shape.dim_sizes()), - xla::Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select( + xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 78bc2c94425e00c2b26058daf609d71f1853664e..ee11f3a3de658c7e5108605122b84fbc3e1cd963 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -319,6 +319,27 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } +Status XlaOpKernelContext::ConstantInputAsPartialShape( + int index, PartialTensorShape* shape) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + // If `literal` is a scalar it's value must be -1. + if (literal.shape().rank() == 0) { + int64 shape_val; + TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); + if (shape_val != -1) { + return errors::InvalidArgument( + "Cannot convert value to PartialTensorShape: ", shape_val); + } + *shape = PartialTensorShape(); // Shape with unknown rank. + return Status::OK(); + } + std::vector dims; + TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims)); + *shape = PartialTensorShape(dims); + return Status::OK(); +} + Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { @@ -447,6 +468,16 @@ void XlaOpKernelContext::SetOutputExpression(int index, } } +xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { SetOutputExpression( index, @@ -503,6 +534,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, handle = xla::Reshape(handle, xla::AsInt64Slice(representation_shape.dimensions())); } + variable->SetRepresentationShape(representation_shape); return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index e44415f60bff82fb92d0cf4ec81935564a2f083a..cc2d5e8de3eb020ba41dfed7d730b48cd0534b4c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -138,6 +138,10 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); + // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 + // into a PartialTensorShape. + Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); + // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. @@ -155,6 +159,11 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } + // Returns the type of output `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType output_xla_type(int index); + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 736588bb8b89ba756cdce77eeebff8d1fcf4774c..ab3a5bdd9bc580c16d65d35c3be3ba8204511f83 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -86,6 +86,12 @@ class XlaResource { // variables have new values that need to be written back. const xla::XlaOp& initial_value() const { return initial_value_; } + // An xla shape that indicates how this resource variable is represented on + // device. + const absl::optional& representation_shape() const { + return representation_shape_; + } + // A variable is initialized if it has a value. bool initialized() const { return value_.valid(); } @@ -100,6 +106,11 @@ class XlaResource { // Sets the current value of the resource to an all-zero value. Status SetZeroValue(xla::XlaBuilder* builder); + // Sets the representational shape of the resource on device. + void SetRepresentationShape(const xla::Shape& shape) { + representation_shape_ = absl::make_optional(shape); + } + // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator @@ -160,6 +171,10 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; + // An xla shape that indicates how this resource variable is represented on + // device. + absl::optional representation_shape_; + int64 max_array_size_ = -1; bool tensor_array_multiple_writes_aggregate_ = false; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 636e5ef721f58c009566c10a653d09a7667619c0..ee6f7d5956ede4af99498ca0df5de47150cc5e4d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -150,8 +150,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", "//tensorflow/stream_executor/lib", ], ) @@ -194,7 +192,7 @@ cc_library( ":types", ":util", "//tensorflow/core:lib", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", ], ) @@ -833,7 +831,6 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index b30ab84240286fe4eb145fc893ba3f3f7ab26d00..c5dea5f18030f2d226c86e3408ea85b2b5989728 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -65,7 +65,6 @@ cc_library( "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:lib", ], ) @@ -231,6 +230,7 @@ cc_library( deps = [ ":arithmetic", ":constants", + ":slicing", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -376,6 +376,7 @@ cc_library( srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ + ":comparators", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -451,42 +452,48 @@ cc_library( ) cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], deps = [ - "//tensorflow/compiler/xla:literal", + ":arithmetic", + ":comparators", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = [ - "enable_for_xla_interpreter", - "noasan", # sometimes times out, http://b/78650012 + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", ], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], deps = [ - ":math", + ":arithmetic", + ":constants", ":matrix", - ":triangular_solve", + ":self_adjoint_eig", "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc index 414bd1494cd32f32a5c37e84119de930678a776b..bb41f9932d1cc62b62d37fea2c10fbfeaa0bd15e 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky.cc +++ b/tensorflow/compiler/xla/client/lib/cholesky.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -194,12 +193,12 @@ XlaOp Cholesky(XlaOp a, int64 block_size, // l[i+k:, i:i+k] = // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); - auto update = TriangularSolve(factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size); + auto update = + TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 1ada7b4a964ccf7ca400b937abbe425bef083468..6bd56a8df0a5d0417f747a158664ed0daa8a7b40 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -80,6 +80,24 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { } } +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + std::numeric_limits::min()); + case BF16: + return ConstantR0(builder, bfloat16::min_positive_normal()); + case F32: + return ConstantR0(builder, std::numeric_limits::min()); + case F64: + return ConstantR0(builder, std::numeric_limits::min()); + default: + return builder->ReportError( + InvalidArgument("Invalid type for MinPositiveNormalValue (%s).", + PrimitiveType_Name(type))); + } +} + XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { return ConstantLiteral(builder, LiteralUtil::MaxValue(type)); } @@ -100,4 +118,28 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { } } +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + switch (type) { + case F16: + return ConstantR0( + builder, Eigen::NumTraits::quiet_NaN()); + case BF16: + return ConstantR0( + builder, bfloat16(std::numeric_limits::quiet_NaN())); + case F32: + return ConstantR0(builder, + std::numeric_limits::quiet_NaN()); + case F64: + return ConstantR0(builder, + std::numeric_limits::quiet_NaN()); + default: + return InvalidArgument( + "Operand to NanValue was %s, but must be a real-valued " + "floating-point type.", + PrimitiveType_Name(type)); + } + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 4e5310a380e8bda15348dae2cbb0ea9e2c381bcb..47b8f1b44ffa12b2b15be0e865d693a709962e6e 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -135,6 +135,9 @@ XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); // point type, this is equal to -MaxFiniteValue(). XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); +// Returns the minimum positive normal value for floating-point type `type`. +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); + // Returns the maximum representable finite or infinite value for 'type'. // Returns 'inf' for floating-point types. XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); @@ -142,6 +145,9 @@ XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); // Returns the maximum representable finite value for 'type'. XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); +// Returns a nan for the given type. Only valid for real-valued fp types. +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc index f4320f65c1f76d4d4c384110b39d6606773aaf01..180175b7495b32250af8ae77c8c7fba804703885 100644 --- a/tensorflow/compiler/xla/client/lib/constants_test.cc +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -155,5 +155,12 @@ XLA_TEST_F(ConstantsTest, MaxValueF32) { {}); } +XLA_TEST_F(ConstantsTest, NanValueF32) { + XlaBuilder builder(TestName()); + NanValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::quiet_NaN(), + {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 14891206855725f1ba71bda9f92134d7c7eb9217..19d98d100191fcba590d64c643d76b2bc5d5e5c5 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -79,6 +79,34 @@ XlaOp IsNan(XlaOp operand) { }); } +XlaOp IsNegZero(XlaOp operand) { + auto& b = *operand.builder(); + return b.ReportErrorOrReturn([&]() -> StatusOr { + TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); + + // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0 + // (sign bit on, all other bits off). + switch (shape.element_type()) { + case F64: + return Eq(BitcastConvertType(operand, U64), + ConstantR0WithType(&b, U64, uint64{1} << 63)); + case F32: + return Eq(BitcastConvertType(operand, U32), + ConstantR0WithType(&b, U32, uint32{1} << 31)); + case F16: + case BF16: + // Not all XLA backends handle U16 well, so we convert to F32/U32. + // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for + // backends that *do* support it. + return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32), + ConstantR0WithType(&b, U32, uint32{1} << 31)); + default: + LOG(FATAL) << "Expected real fp type."; + } + }); +} + XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 907571c9a3ec65b0be0087ad4837c842a0bdcc79..b036fa299d92988439dfecbb5415865071d5577d 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -28,6 +28,11 @@ XlaOp IsNegInf(XlaOp operand); XlaOp IsInf(XlaOp operand); XlaOp IsNan(XlaOp operand); +// Determines whether operand is equal to -0. +// +// Raises an error for integral or complex values. +XlaOp IsNegZero(XlaOp operand); + // Returns the next number after 'from' in the direction of 'to' the same way // std::nextafter(from, to) would. XlaOp NextAfter(XlaOp from, XlaOp to); diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 364ac5876abbec825834081518a6dfda84356048..bdfb0575f573716b54cf9116d155d8a3a55056e8 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -88,6 +88,22 @@ class MathTypedTest : public MathTest { {false, false, false, false, false, false, false, true, true})); ComputeAndCompareLiteral(&b, expected, {}); } + + void TestIsNegZero() { + SetFastMathDisabled(true); + XlaBuilder b(TestName()); + T inf(std::numeric_limits::infinity()); + T nan(std::numeric_limits::quiet_NaN()); + IsNegZero(AddParam( + LiteralUtil::CreateR1({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}), + &b)); + + ComputeAndCompareLiteral( + &b, + LiteralUtil::CreateR1( + {true, false, false, false, false, false, false}), + {}, error_spec_); + } }; // TODO(b/123355973): Add bfloat16 to TestTypes once it's working. @@ -102,6 +118,7 @@ TYPED_TEST_CASE(MathTypedTest, TestTypes); XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); } XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); } XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); } +XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); } // Check that certain ops only support real, floating-point inputs. // diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc index a5aea96090c59c78d20cfc10a4bd6b312be592c1..a055a8e625c680cf5232896c95cd35b78cb172bc 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.cc +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -45,7 +46,7 @@ XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, return ConvertElementType(indicator, type); } -XlaOp GetMatrixDiagonal(XlaOp x) { +XlaOp GetMatrixDiagonal(XlaOp x, int k) { XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); @@ -53,10 +54,13 @@ XlaOp GetMatrixDiagonal(XlaOp x) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); + + auto offset = ConstantR0WithType(builder, S32, k); + absl::Span major_dims = AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); + auto a = Iota(builder, S32, n); + auto b = Iota(builder, S32, m) + offset; auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); auto mask = Broadcast(indicator, major_dims); @@ -66,9 +70,21 @@ XlaOp GetMatrixDiagonal(XlaOp x) { primitive_util::IsIntegralType(shape.element_type()) ? CreateScalarOrComputation(shape.element_type(), builder) : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + // k == 0, we can save one slice op. + if (k == 0) { + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + } else if (k > 0) { + auto result = Reduce(Select(mask, x, Zeros(builder, shape)), + ScalarLike(x, 0), reducer, {n_dims - 2}); + return SliceInMinorDims(result, {std::min(k, n)}, + {std::min(m + k, n)}); + } else { + auto result = Reduce(Select(mask, x, Zeros(builder, shape)), + ScalarLike(x, 0), reducer, {n_dims - 1}); + return SliceInMinorDims(result, {std::min(-k, m)}, + {std::min(m, n - k)}); + } }); } @@ -336,4 +352,5 @@ XlaOp TransposeInMinorDims(XlaOp x) { XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { return transpose ? TransposeInMinorDims(x) : x; } + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h index 491f1eab4cbffbbf9df70d4c35a61351df3e98aa..60c41ec45a086726086dac7227fc432a9c62d0c8 100644 --- a/tensorflow/compiler/xla/client/lib/matrix.h +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -30,10 +30,15 @@ namespace xla { // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); -// Get the diagonals of the last two dimensions. If 'x' has shape -// [..., M, N], then the output has shape [..., min(M, N)], containing the -// diagonal elements (i.e., with indices [..., i, i]). -XlaOp GetMatrixDiagonal(XlaOp x); +// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the +// main diagonal, and k<0 for diagonals below the main diagonal. +// +// If 'x' has shape [..., M, N] +// If k >= 0: then the output has shape [..., min(M, N - k)], containing the +// diagonal elements (i.e., with indices [..., i, i + k]). +// If k < 0: then the output has shape [..., min(M + k, N)], containing the +// diagonal elements (i.e., with indices [..., i - k, i]). +XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); // Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal // and false above that diagonal. diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc index 79cf529ee94b044ee0af788522200cd28c778997..a93fc2ccb92912a10b9b6c2192b81cd73566f2a0 100644 --- a/tensorflow/compiler/xla/client/lib/matrix_test.cc +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { @@ -54,13 +53,24 @@ void MatrixTest::TestMatrixDiagonal() { XlaBuilder builder("GetMatrixDiagonal"); Array3D input(2, 3, 4); input.FillIota(0); + std::map> k_and_expected = { + {0, {{0, 5, 10}, {12, 17, 22}}}, + {1, {{1, 6, 11}, {13, 18, 23}}}, + {2, {{2, 7}, {14, 19}}}, + {3, {{3}, {15}}}, + {4, {{}, {}}}, + {-1, {{4, 9}, {16, 21}}}, + {-2, {{8}, {20}}}, + {-3, {{}, {}}}, + {-4, {{}, {}}}, + }; + for (const auto& kv : k_and_expected) { + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a, kv.first); - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - GetMatrixDiagonal(a); - Array2D expected({{0, 5, 10}, {12, 17, 22}}); - - ComputeAndCompareR2(&builder, expected, {a_data.get()}); + ComputeAndCompareR2(&builder, kv.second, {a_data.get()}); + } } XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc new file mode 100644 index 0000000000000000000000000000000000000000..546127e4627f1717913d1039be13fd0c655be1a3 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -0,0 +1,471 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// Jacobi rotation (also known as Givens rotation): +// G = [[ c, s], +// [-s, c]] +// matmul(G_T, G) = I +struct SymmetricSchurDecomposition { + XlaOp c; // cosine. + XlaOp s; // sine. +}; + +// JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix +// and the off-diagonal norm of the rotated matrix. After each Jacobi iteration, +// off-diagonal norm is reduced. +struct JacobiUpdate { + XlaOp v; + XlaOp w; +}; + +struct FrobeniusNorms { + XlaOp off_diagonal_norm; + XlaOp total_norm; +}; + +// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n, +// it computes a rotation matrix G = [[c, s], [-s, c]], such that +// G_T * A[[p, q], [p, q]] * G +// is diagonalized. +// +// def sym_schur2x2(A, p, q): +// if np.abs(A[p, q]) > 1e-6: +// tau = (A[q, q] - A[p, p]) / (2 * A[p, q]) +// if tau >= 0: +// t = 1.0 / (tau + np.sqrt(1 + tau ** 2)) +// else: +// t = -1.0 / (-tau + np.sqrt(1 + tau ** 2)) +// c = 1.0 / np.sqrt(1.0 + t ** 2) +// s = t * c +// else: +// c = 1.0 +// s = 0.0 +// return c, s +StatusOr SymmetricShurDecomposition2x2(XlaOp a, + XlaOp p, + XlaOp q, + XlaOp tol) { + XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + + auto zero = ScalarLike(a, 0.0); + auto one = ScalarLike(a, 1.0); + auto two = ScalarLike(a, 2.0); + + auto pqs = DynamicSliceInMinorDims(a, {p, q}, {1, 1}); + + auto ps = DynamicSliceInMinorDims(a, {p, p}, {1, 1}); + auto qs = DynamicSliceInMinorDims(a, {q, q}, {1, 1}); + + auto tau = (qs - ps) / (pqs * two); + auto t_pos = one / (tau + Sqrt(one + Square(tau))); + auto t_neg = -one / (-tau + Sqrt(one + Square(tau))); + auto t = Select(Ge(tau, zero), t_pos, t_neg); + + auto c_temp = Rsqrt(one + Square(t)); + auto s_temp = t * c_temp; + + auto c = Select(Ge(Abs(pqs), tol), c_temp, ZerosLike(c_temp) + one); + auto s = Select(Ge(Abs(pqs), tol), s_temp, ZerosLike(s_temp)); + // Renormalize c and s to compensate for low precision arithmetic, this step + // is redundant if high precision float is used, like float64. + auto rnorm = Rsqrt(Square(c) + Square(s)); + + SymmetricSchurDecomposition schur; + + schur.c = c * rnorm; + schur.s = s * rnorm; + + return schur; +} + +StatusOr Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q, + XlaOp tol, int64 n) { + XlaBuilder* builder = jacobi_update.w.builder(); + TF_ASSIGN_OR_RETURN( + SymmetricSchurDecomposition schur, + SymmetricShurDecomposition2x2(jacobi_update.w, p, q, tol)); + + TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(jacobi_update.w)); + const std::vector batch_dims(w_shape.dimensions().begin(), + w_shape.dimensions().end() - 2); + const int64 num_dims = w_shape.rank(); + + auto zero = ScalarLike(p, 0); + + XlaOp c = schur.c; + XlaOp s = schur.s; + + auto slice_p = DynamicSliceInMinorDims(jacobi_update.w, {p, zero}, {1, n}); + auto slice_q = DynamicSliceInMinorDims(jacobi_update.w, {q, zero}, {1, n}); + + auto slice_p_new = c * slice_p - s * slice_q; + auto slice_q_new = s * slice_p + c * slice_q; + + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {p, zero}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {q, zero}); + + slice_p = DynamicSliceInMinorDims(jacobi_update.w, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(jacobi_update.w, {zero, q}, {n, 1}); + + slice_p_new = c * slice_p - s * slice_q; + slice_q_new = s * slice_p + c * slice_q; + + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {zero, p}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {zero, q}); + + // Zero out a_{pq} explicitly. + std::vector pq_dims(batch_dims.begin(), batch_dims.end()); + pq_dims.push_back(1); + pq_dims.push_back(1); + auto pq_zero = ScalarLike(jacobi_update.w, 0.0); + auto pq_zeros = Broadcast(pq_zero, pq_dims); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {p, q}); + jacobi_update.w = + DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {q, p}); + + slice_p = DynamicSliceInMinorDims(jacobi_update.v, {zero, p}, {n, 1}); + slice_q = DynamicSliceInMinorDims(jacobi_update.v, {zero, q}, {n, 1}); + + std::vector broadcast_dims(batch_dims.size()); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims.push_back(num_dims - 1); + + // Renormalize the p-th and q-th columns. This step is redundant if high + // precision floats are used, like 64-bit float. But for 32-bit float, it + // becomes necessary. This step will not increase the overall complexity. + slice_p_new = c * slice_p - s * slice_q; + slice_p_new = Mul( + slice_p_new, + Rsqrt(Reduce(Square(slice_p_new), pq_zero, + CreateScalarAddComputation(w_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + slice_q_new = s * slice_p + c * slice_q; + slice_q_new = Mul( + slice_q_new, + Rsqrt(Reduce(Square(slice_q_new), pq_zero, + CreateScalarAddComputation(w_shape.element_type(), builder), + {num_dims - 2})), + broadcast_dims); + + jacobi_update.v = + DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_p_new, {zero, p}); + jacobi_update.v = + DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q}); + + return jacobi_update; +} + +StatusOr ComputeFrobeniusNorms(XlaOp w) { + XlaBuilder* builder = w.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w)); + const int64 num_dims = shape.rank(); + auto frobenius_norm = + Sqrt(Reduce(Square(w), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2, num_dims - 1})); + auto diag = GetMatrixDiagonal(w); + auto diag_square = + Reduce(Square(diag), ScalarLike(w, 0.0), + CreateScalarAddComputation(shape.element_type(), builder), + {num_dims - 2}); + + FrobeniusNorms frobenius_norms; + + frobenius_norms.off_diagonal_norm = + Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0))); + frobenius_norms.total_norm = frobenius_norm; + + return frobenius_norms; +} + +StatusOr> WhileLoopFn( + absl::Span initial_values, // + int matrix_dimension, // + int max_sweep_updates, // + PrimitiveType index_type, // + absl::string_view name, // + XlaBuilder* builder) { + auto while_cond_fn = [&](absl::Span values, + XlaBuilder* cond_builder) -> StatusOr { + auto k = values[0]; + auto max_sweeps = ScalarLike(k, max_sweep_updates); + auto sweep_update_cond = Gt(max_sweeps, k); + + auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie(); + auto tol = norms.total_norm * values[3]; + auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), + xla::ConstantR0(cond_builder, false), + CreateScalarOrComputation(PRED, cond_builder)); + + return And(sweep_update_cond, tol_cond); + }; + + auto while_body_fn = + [&](absl::Span values, + XlaBuilder* body_builder) -> StatusOr> { + auto while_cond_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_cond_builder) -> StatusOr { + auto p = values_inner[0]; + return Lt(p, ScalarLike(p, matrix_dimension - 1)); + }; + + auto while_body_fn_inner = + [&](absl::Span values_inner, + XlaBuilder* inner_body_builder) -> StatusOr> { + auto while_cond_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_cond_builder) -> StatusOr { + auto q = values_innermost[1]; + return Lt(q, ScalarLike(q, matrix_dimension)); + }; + auto while_body_fn_innermost = + [&](absl::Span values_innermost, + XlaBuilder* innermost_body_builder) + -> StatusOr> { + auto p = values_innermost[0]; + auto q = values_innermost[1]; + + JacobiUpdate jacobi_update; + jacobi_update.v = values_innermost[2]; + jacobi_update.w = values_innermost[3]; + + auto tol = values_innermost[4]; + + TF_ASSIGN_OR_RETURN(jacobi_update, + Update(jacobi_update, p, q, tol, matrix_dimension)); + + std::vector updated_values_innermost; + updated_values_innermost.reserve(values_innermost.size()); + + updated_values_innermost.push_back(p); + updated_values_innermost.push_back(q + ScalarLike(q, 1)); + updated_values_innermost.push_back(jacobi_update.v); + updated_values_innermost.push_back(jacobi_update.w); + updated_values_innermost.push_back(tol); + + return updated_values_innermost; + }; + + std::vector values_innermost(5); + auto p = values_inner[0]; + auto q = p + ScalarLike(p, 1); + values_innermost[0] = p; // index p. + values_innermost[1] = q; // index q. + values_innermost[2] = values_inner[1]; // v. + values_innermost[3] = values_inner[2]; // w. + values_innermost[4] = values_inner[3]; // tol. + TF_ASSIGN_OR_RETURN( + values_innermost, + WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost, + values_innermost, absl::StrCat(name, "-Innermost"), + inner_body_builder)); + + std::vector updated_values_inner; + updated_values_inner.reserve(values_inner.size()); + + updated_values_inner.push_back(p + ScalarLike(p, 1)); + updated_values_inner.push_back(values_innermost[2]); + updated_values_inner.push_back(values_innermost[3]); + updated_values_inner.push_back(values_innermost[4]); + return updated_values_inner; + }; + // Indexes. + XlaOp k = values[0]; + + std::vector values_inner(4); + values_inner[0] = ScalarLike(k, 0); // index p. + values_inner[1] = values[1]; // v. + values_inner[2] = values[2]; // w. + values_inner[3] = values[3]; // tol. + TF_ASSIGN_OR_RETURN( + values_inner, + WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner, + absl::StrCat(name, "-Inner"), body_builder)); + + std::vector updated_values; + updated_values.reserve(values_inner.size()); + + updated_values.push_back(k + ScalarLike(k, 1)); + updated_values.push_back(values_inner[1]); + updated_values.push_back(values_inner[2]); + updated_values.push_back(values_inner[3]); + + return updated_values; + }; + std::vector values; + TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn, + initial_values, name, builder)); + + return values; +} + +StatusOr SortByEigenvalues(SelfAdjointEigResult result) { + XlaBuilder* builder = result.v.builder(); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.v)); + const int64 num_dims = shape.rank(); + auto dimensions = shape.dimensions(); + + std::vector broadcast_dims(num_dims - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims[num_dims - 2] = num_dims - 1; + result.w = BroadcastInDim(result.w, dimensions, broadcast_dims); + + XlaOp sort_result = + Sort({result.w, result.v}, + CreateScalarLtComputation( + {shape.element_type(), shape.element_type()}, builder), + num_dims - 1); + result.w = GetMatrixDiagonal(GetTupleElement(sort_result, 0)); + result.v = GetTupleElement(sort_result, 1); + return result; +} + +} // namespace + +// This is the cyclic Jacobi iteration. Please note that the eigenvalues are +// possibly not ordered. +// +// def jacobi(A): +// n, _ = A.shape +// V = np.eye(n) +// frobenius_norm = np.linalg.norm(A) +// diag_norm = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm) +// while off_diag_norm > 1e-6 * frobenius_norm: +// for p in range(n - 1): +// for q in range(p + 1, n): +// c, s = sym_schur2x2(A, p, q) +// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]), +// A[[p, q], :]) +// A[:, [p, q]] = np.matmul(A[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// V[:, [p, q]] = np.matmul(V[:, [p, q]], +// np.array([[c, s], [-s, c]])) +// frobenius_norm_sq = np.linalg.norm(A) +// diag_square_sum = np.linalg.norm(np.diag(A)) +// off_diag_norm = np.sqrt( +// frobenius_norm - diag_norm) * np.sqrt( +// frobenius_norm + diag_norm) +// +// return A, V +// +// TODO(kuny): Implement parallel order Jacobi. +// +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, + float epsilon) { + XlaBuilder* builder = a.builder(); + auto return_error = [&](const Status& status) { + SelfAdjointEigResult result; + result.v = builder->ReportError(status); + result.w = builder->ReportError(status); + return result; + }; + auto shape_with_status = builder->GetShape(a); + if (!shape_with_status.status().ok()) { + return return_error(shape_with_status.status()); + } + Shape a_shape = shape_with_status.ValueOrDie(); + const int64 num_dims = a_shape.rank(); + if (num_dims < 2) { + return return_error(InvalidArgument( + "Arguments to Eigen decomposition must have rank >= 2: got shape %s.", + a_shape.ToString())); + } + PrimitiveType type = a_shape.element_type(); + if (!primitive_util::IsFloatingPointType(type)) { + return return_error(InvalidArgument( + "Type of the input matrix must be float: got %s.", a_shape.ToString())); + } + + const int64 m = ShapeUtil::GetDimension(a_shape, -2); + const int64 n = ShapeUtil::GetDimension(a_shape, -1); + + if (m != n) { + return return_error(InvalidArgument( + "Arguments to Eigen decomposition must be square matrices: got shape " + "(%d, %d).", + m, n)); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); + } + + auto tol = ScalarLike(a, epsilon); + + auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); + auto w_init = Triangle(a, lower); + w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init; + + auto output_with_status = WhileLoopFn( + { + Zero(builder, S32), // k + v_init, // v + w_init, // w + tol, // + }, // + n, // + max_iter, // + S32, // + "CyclicJacobi", // + builder); + if (!output_with_status.status().ok()) { + return return_error(output_with_status.status()); + } + + auto output = output_with_status.ValueOrDie(); + + SelfAdjointEigResult result; + result.v = output[1]; + result.w = GetMatrixDiagonal(output[2]); + + return SortByEigenvalues(result).ValueOrDie(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h new file mode 100644 index 0000000000000000000000000000000000000000..2a089891d6a2d80c0c265a3310539b4f1c5db4d5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The eigenvalue decomposition of a symmetric matrix, the original matrix is +// recovered by v * w * v_t. +struct SelfAdjointEigResult { + // The i-th column is the normalized eigenvector corresponding to the + // eigenvalue w[i]. Will return a matrix object if a is a matrix object. + XlaOp v; + // The eigenvalues in ascending order, each repeated according to its + // multiplicity. + XlaOp w; +}; + +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64 max_iter = 100, float epsilon = 1e-6); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8875dff7bfdbd4e133297cef0a6686bfcd9bb6f --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -0,0 +1,313 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +class SelfAdjointEigTest : public ClientLibraryTestBase { + protected: + void SetUp() override { + ClientLibraryTestBase::SetUp(); + batch_3d_4x4_ = Array3D{ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 100, 6}, + {12, 48, 6, 62}, + }, + }; + matrix2d_8x8_ = Array2D{ + {14., 123., 49., 112., 115., 173., 182., 125.}, + {123., 14., 60., 118., 150., 130., 91., 72.}, + {49., 60., 138., 111., 106., 101., 115., 142.}, + {112., 118., 111., 142., 91., 130., 25., 61.}, + {115., 150., 106., 91., 116., 121., 128., 85.}, + {173., 130., 101., 130., 121., 70., 151., 132.}, + {182., 91., 115., 25., 128., 151., 66., 92.}, + {125., 72., 142., 61., 85., 132., 92., 156.}, + }; + low_rank_4x4_ = Array2D{ + // x = [[1, 2, 3, 4], [1, -1, 1, -1]] + // matmul(x.T, x) + {2, 1, 4, 3}, + {1, 5, 5, 9}, + {4, 5, 10, 11}, + {3, 9, 11, 17}, + }; + } + void TearDown() override { ClientLibraryTestBase::TearDown(); } + + Array3D GetUnitMatrix3D(const Array3D& matrix) { + Array3D result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); + for (int i = 0; i < matrix.n1(); ++i) { + for (int j = 0; j < matrix.n2(); ++j) { + result({i, j, j}) = 1.0; + } + } + return result; + } + + Array3D ExtractTriangularMatrix(const Array3D& matrix, + bool lower) { + Array3D result(matrix); + for (int i = 0; i < result.n1(); ++i) { + for (int j = 0; j < result.n2(); ++j) { + if (lower) { + for (int k = j + 1; k < result.n3(); ++k) { + result({i, j, k}) = 0.0; + } + } else { + for (int k = 0; k < j; ++k) { + result({i, j, k}) = 0.0; + } + } + } + } + return result; + } + + XlaOp ComputeMatmulVWVt(SelfAdjointEigResult result, XlaBuilder* builder) { + Shape shape = builder->GetShape(result.v).ValueOrDie(); + std::vector out_dims = shape.dimensions(); + std::vector broadcast_dims(shape.rank() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + + broadcast_dims[shape.rank() - 2] = shape.rank() - 1; + auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims)); + return BatchDot(vw, TransposeInMinorDims(result.v), + PrecisionConfig::HIGHEST); + } + + XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { + Shape shape = builder->GetShape(m1).ValueOrDie(); + int64 size = 1; + for (auto d : shape.dimensions()) { + size *= d; + } + return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0), + CreateScalarAddComputation(F32, builder)) / + ConstantR0WithType(builder, F32, size); + } + + Array2D GenerateRandomSymmetricMatrix(int size) { + Array2D result{size, size, 0.0}; + result.FillRandom(10 /* stddev */, 2 /* mean */); + for (int i = 0; i < size; ++i) { + for (int j = 0; j < i; ++j) { + result({j, i}) = result({i, j}); + } + } + return result; + } + + Array3D batch_3d_4x4_; + Array2D matrix2d_8x8_; + Array2D low_rank_4x4_; + Array2D wrong_type_4x4_; +}; + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Lower_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter( + ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VWVt_EQ_A_Upper_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter( + ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a); + auto result = SelfAdjointEig(a, false); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR3(&builder, batch_3d_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_2x4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR3Parameter(batch_3d_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST); + + ComputeAndCompareR3(&builder, GetUnitMatrix3D(batch_3d_4x4_), + {a_data.get()}, ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR2Parameter(low_rank_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + ComputeMatmulVWVt(result, &builder); + + ComputeAndCompareR2(&builder, low_rank_4x4_, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Eigen_8x8) { + XlaBuilder builder(TestName()); + + // This is computed by numpy.linalg.eigh with float32. + std::vector expected{-182.69205, -116.86245, -105.74489, -9.545369, + 37.81711, 104.732285, 120.29153, 868.00385}; + + XlaOp a; + auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + Add(result.w, ZerosLike(result.w)); + + ComputeAndCompareR1(&builder, expected, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Test_Orthogonality_8x8) { + XlaBuilder builder(TestName()); + + float expected_vals = 1e-3; + + XlaOp a; + auto a_data = CreateR2Parameter(matrix2d_8x8_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2 + GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8), + BatchDot(TransposeInMinorDims(result.v), result.v), + &builder); + + ComputeAndCompareR0(&builder, expected_vals, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Wrong_Type_Int) { + XlaBuilder builder(TestName()); + + XlaOp a; + auto a_data = CreateR2Parameter(wrong_type_4x4_, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + EXPECT_FALSE(result.v.valid()); + EXPECT_FALSE(result.w.valid()); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_8x8) { + XlaBuilder builder(TestName()); + int size = 8; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_16x16) { + XlaBuilder builder(TestName()); + int size = 16; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_32x32) { + XlaBuilder builder(TestName()); + int size = 32; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_256x256) { + XlaBuilder builder(TestName()); + int size = 256; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +XLA_TEST_F(SelfAdjointEigTest, Various_Size_Random_Matrix_512x512) { + XlaBuilder builder(TestName()); + int size = 512; + Array2D a_val = GenerateRandomSymmetricMatrix(size); + XlaOp a; + auto a_data = CreateR2Parameter(a_val, 0, "a", &builder, &a); + auto result = SelfAdjointEig(a); + GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder); + + ComputeAndCompareR0(&builder, 1e-3, {a_data.get()}, + ErrorSpec(1e-3, 1e-3)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 77145ba7d4c72435450d3e33d57b2507eb84d2fc..d7b33c5af25606c4e7e443027b913f7ca13a013c 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -134,4 +134,31 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, }); } +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) { + XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); + ShapeUtil::AppendMajorDimension(1, &index_shape); + std::vector to_concat; + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + to_concat.reserve(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + if (i == dim) { + to_concat.push_back(Reshape(index, index_shape.dimensions())); + } else { + to_concat.push_back(Iota(builder, index_shape, i)); + } + } + XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank()); + std::vector slice_sizes(input_shape.rank(), 1); + GatherDimensionNumbers gather_dnums; + gather_dnums.set_index_vector_dim(input_shape.rank()); + for (int64 i = 0; i < input_shape.rank(); ++i) { + gather_dnums.add_collapsed_slice_dims(i); + gather_dnums.add_start_index_map(i); + } + return Gather(input, gather_indices, gather_dnums, slice_sizes); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890..69f98a6f43fa167adf6f77b28645a3460b292633 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -43,6 +43,20 @@ XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, absl::Span starts); +// Gathers values along an axis specified by dim. +// +// For a 3-D tensor the output is specified by: +// +// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +// +// If `input` is an n-dimensional tensor with size +// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size +// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as +// `index`. +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 8d362119e01006555db0f82d02626175936e1d05..db6ebb9df18372260a64a3e9fd17b0c30b35667d 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -102,5 +102,18 @@ XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } +XLA_TEST_F(SlicingTest, TorchGather) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 0, "input", &builder, &input); + auto index_data = + CreateR2Parameter({{0, 0}, {1, 0}}, 1, "index", &builder, &index); + TorchGather(input, index, 1); + + ComputeAndCompareR2(&builder, {{1, 1}, {4, 3}}, + {input_data.get(), index_data.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index e8553a08bb014e790822a14e128686b60b8d6b7c..ddc39f4d874cd3613a763b969091e7e65ff1c783 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -30,7 +31,13 @@ XlaOp TopK(XlaOp input, int64 k) { ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - XlaOp sort_result = Sort(Neg(input), {iota_s32}); + // TODO(b/122298745): Get rid of Neg() and use CreateScalarGtComputation + // once the TPU backend supports the comparison computations. + XlaOp sort_result = + Sort({Neg(input), iota_s32}, + CreateScalarLtComputation({input_shape.element_type(), S32}, + iota_s32.builder()), + last_dim, /*is_stable=*/true); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h deleted file mode 100644 index 50a3b30ebd1c15eb6d2ace4e351cb41f21db7093..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Solves systems of linear equations with lower or upper triangular coefficient -// matrices by forward- or back-substitution. Broadcasting along leading -// dimensions, this routine solves one of the matrix systems -// `op(a) * x = b`, or `x * op(a) = b`, -// for the variable `x` given `a` and `b`, where `op(a)` is either -// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. -// That is, the innermost matrices in the output satisfy a scalar system -// depending on the value of the value of (left_side, transpose_a, conjugate_a) -// according to: -// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`, -// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`, -// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`, -// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`, -// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`, -// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`, -// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`, -// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`, -// where * denotes complex conjugation and where the index `k` is summed over. -// -// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form -// square matrices. If lower is true (false), then the strictly upper (lower) -// triangular part of each innermost matrix in `a` is assumed to be zero and is -// not accessed. -// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a -// tensor of shape `[..., K, M]`. -// `left_side` is a boolean, indicating whether to solve a system of the form -// op(a) * x = b (true) or x * op(a) = b (false). -// `lower` is a boolean, indicating whether the argument `a` is lower-triangular -// (true) or upper-triangular (false). -// `transpose_a` is a boolean indicating whether the matrix `a` is transposed. -// `conjugate_a` is a boolean indicating whether the entries of `a` are complex -// conjugated (independently of whether they are transposed), so that when both -// transpose_a and conjugate_a are true the effect is a Hermitian adjoint. -// -// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no -// blocking is used. -XlaOp TriangularSolve( - XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size = 128, - PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index c00ba26295a30c192fedae48f5aabf78cbd7d831..9b7c01a727a5aa0aebf600584b59aec4d47aa0b4 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -299,12 +299,17 @@ XlaComputation XlaBuilder::BuildAndNoteError() { return build_status.ConsumeValueOrDie(); } -StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { +Status XlaBuilder::GetCurrentStatus() const { if (!first_error_.ok()) { string backtrace; first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + return Status::OK(); +} + +StatusOr XlaBuilder::Build(bool remove_dynamic_dimensions) { + TF_RETURN_IF_ERROR(GetCurrentStatus()); return Build(instructions_.back().id(), remove_dynamic_dimensions); } @@ -318,11 +323,7 @@ StatusOr XlaBuilder::Build(XlaOp root, StatusOr XlaBuilder::Build(int64 root_id, bool remove_dynamic_dimensions) { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } + TF_RETURN_IF_ERROR(GetCurrentStatus()); // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove // all dynamic dimensions before building xla program until we have support in @@ -573,16 +574,6 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, }); } -XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1004,36 +995,6 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { }); } -XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1054,6 +1015,18 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + // If one operand is a scalar, just multiply the two operands. + if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) { + if (dimension_numbers.rhs_batch_dimensions_size() != 0 || + dimension_numbers.lhs_batch_dimensions_size() != 0 || + dimension_numbers.rhs_contracting_dimensions_size() != 0 || + dimension_numbers.lhs_contracting_dimensions_size() != 0) { + return InvalidArgument( + "Dots with scalar operands must have no contracting or batch " + "dimensions"); + } + return xla::Mul(lhs, rhs); + } TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); @@ -1549,147 +1522,6 @@ XlaOp XlaBuilder::CustomCall( }); } -XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); -} - -XlaOp XlaBuilder::Conj(const XlaOp& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::Not(const XlaOp& operand) { - return UnaryOp(HloOpcode::kNot, operand); -} - -XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); -} - -XlaOp XlaBuilder::ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, - broadcast_dimensions); -} - -XlaOp XlaBuilder::ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, - broadcast_dimensions); -} - -XlaOp XlaBuilder::Abs(const XlaOp& operand) { - return UnaryOp(HloOpcode::kAbs, operand); -} - -XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); -} - -XlaOp XlaBuilder::Exp(const XlaOp& operand) { - return UnaryOp(HloOpcode::kExp, operand); -} - -XlaOp XlaBuilder::Expm1(const XlaOp& operand) { - return UnaryOp(HloOpcode::kExpm1, operand); -} - -XlaOp XlaBuilder::Floor(const XlaOp& operand) { - return UnaryOp(HloOpcode::kFloor, operand); -} - -XlaOp XlaBuilder::Ceil(const XlaOp& operand) { - return UnaryOp(HloOpcode::kCeil, operand); -} - -XlaOp XlaBuilder::Round(const XlaOp& operand) { - return UnaryOp(HloOpcode::kRoundNearestAfz, operand); -} - -XlaOp XlaBuilder::Log(const XlaOp& operand) { - return UnaryOp(HloOpcode::kLog, operand); -} - -XlaOp XlaBuilder::Log1p(const XlaOp& operand) { - return UnaryOp(HloOpcode::kLog1p, operand); -} - -XlaOp XlaBuilder::Sign(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSign, operand); -} - -XlaOp XlaBuilder::Clz(const XlaOp& operand) { - return UnaryOp(HloOpcode::kClz, operand); -} - -XlaOp XlaBuilder::Cos(const XlaOp& operand) { - return UnaryOp(HloOpcode::kCos, operand); -} - -XlaOp XlaBuilder::Sin(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSin, operand); -} - -XlaOp XlaBuilder::Tanh(const XlaOp& operand) { - return UnaryOp(HloOpcode::kTanh, operand); -} - -XlaOp XlaBuilder::Real(const XlaOp& operand) { - return UnaryOp(HloOpcode::kReal, operand); -} - -XlaOp XlaBuilder::Imag(const XlaOp& operand) { - return UnaryOp(HloOpcode::kImag, operand); -} - -XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { - return UnaryOp(HloOpcode::kIsFinite, operand); -} - XlaOp XlaBuilder::Transpose(const XlaOp& operand, absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1720,36 +1552,146 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } +namespace { +// Switch from a floating point value to a integer value in such a way that when +// using the integer value to compare, we get the same result for normal values, +// and -Nan is treated as the smallest value, and Nan is treated as the largest +// value. +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? numeric_limits::max() - x : x; +// then y is ordered as an int32 such that finite values have the obvious order, +// -0 is ordered before 0, and -NaN and NaN appear at the beginning and end of +// the ordering. +// Note that in order to avoid -x to overflow, we calculate +// numeric_limits::max() - x as unsigned, and then convert back to +// signed. +XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, + int64 bit_width) { + PrimitiveType signed_type; + PrimitiveType unsigned_type; + XlaOp max_value; + switch (bit_width) { + case 16: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S16; + unsigned_type = U16; + break; + case 32: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S32; + unsigned_type = U32; + break; + case 64: + max_value = + ConstantR0(value.builder(), + static_cast(std::numeric_limits::max())); + signed_type = S64; + unsigned_type = U64; + break; + default: + return value.builder()->ReportError( + InvalidArgument("Invalid bit width %lld for Comparator floating " + "point parameter.", + bit_width)); + } + auto signed_value = BitcastConvertType(value, signed_type); + auto unsigned_value = BitcastConvertType(value, unsigned_type); + auto flipped_value = + BitcastConvertType(Sub(max_value, unsigned_value), signed_type); + auto is_negative = + Lt(signed_value, + ConstantLiteral(value.builder(), LiteralUtil::Zero(signed_type))); + return Select(is_negative, flipped_value, signed_value); +} +} // namespace + XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + std::vector operands{keys}; + for (const XlaOp& value : values) { + operands.push_back(value); + } + // Build the default less-than comparator (copied from lib/comparators.cc). + // TODO(b/122298745): Remove the deprecated API method so that this code + // duplication can be deleted. + auto b = this->CreateSubBuilder("comparator"); + std::vector operand_types; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + + int64 parameter_count = 0; + XlaOp first_lhs_param; + XlaOp first_rhs_param; + + for (auto operand_type : operand_types) { + auto scalar_shape = ShapeUtil::MakeShape(operand_type, {}); + auto lhs_param = + b->Parameter(parameter_count * 2, scalar_shape, + absl::StrCat("p.", parameter_count, ".lhs")); + auto rhs_param = + b->Parameter(parameter_count * 2 + 1, scalar_shape, + absl::StrCat("p.", parameter_count, ".rhs")); + if (parameter_count == 0) { + first_lhs_param = lhs_param; + first_rhs_param = rhs_param; + } + ++parameter_count; + } + if (primitive_util::IsFloatingPointType(operand_types[0])) { + PrimitiveType compare_type = operand_types[0]; + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. + if (compare_type == BF16) { + compare_type = F32; + first_lhs_param = b->ConvertElementType(first_lhs_param, F32); + first_rhs_param = b->ConvertElementType(first_rhs_param, F32); + } + int64 bit_width = primitive_util::BitWidth(compare_type); + first_lhs_param = + BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width); + first_rhs_param = + BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width); + } + Lt(first_lhs_param, first_rhs_param); + + TF_ASSIGN_OR_RETURN(auto comparator, b->Build()); + return Sort(operands, comparator, dimension, /*is_stable=*/false); + }); +} + +XlaOp XlaBuilder::Sort(absl::Span operands, + const XlaComputation& comparator, int64 dimension, + bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_is_stable(is_stable); std::vector operand_shape_ptrs; - TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); - operand_shape_ptrs.push_back(&keys_shape); - TF_ASSIGN_OR_RETURN(std::vector values_shapes, - GetOperandShapes(values)); - absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), + TF_ASSIGN_OR_RETURN(std::vector operand_shapes, + GetOperandShapes(operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { - TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(operands[0])); dimension = keys_shape.rank() - 1; } instr.add_dimensions(dimension); - std::vector operands{keys}; - operands.insert(operands.end(), values.begin(), values.end()); + AddCalledComputation(comparator, &instr); return AddInstruction(std::move(instr), HloOpcode::kSort, operands); }); } -XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions) { - return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); -} - XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1775,10 +1717,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::Neg(const XlaOp& operand) { - return UnaryOp(HloOpcode::kNegate, operand); -} - XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return TernaryOp(HloOpcode::kClamp, min, operand, max); @@ -2159,8 +2097,8 @@ XlaOp XlaBuilder::CrossReplicaSum( TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); auto b = CreateSubBuilder("sum"); - b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), - b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); return CrossReplicaSum(operand, computation, replica_groups, /*channel_id=*/absl::nullopt); @@ -2956,32 +2894,38 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs, + broadcast_dimensions); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs, + broadcast_dimensions); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs, + broadcast_dimensions); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs, + broadcast_dimensions); } -XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs, + broadcast_dimensions); } -XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs, + broadcast_dimensions); } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, @@ -3055,6 +2999,29 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type, return operand.builder()->Fft(operand, fft_type, fft_length); } +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(const Shape& b_shape, builder->GetShape(b)); + xla::TriangularSolveOptions& options = + *instr.mutable_triangular_solve_options(); + options.set_left_side(left_side); + options.set_lower(lower); + options.set_unit_diagonal(unit_diagonal); + options.set_transpose_a(transpose_a); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape( + a_shape, b_shape, options)); + *instr.mutable_shape() = shape.ToProto(); + + return builder->AddInstruction(std::move(instr), + HloOpcode::kTriangularSolve, {a, b}); + }); +} + XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { return builder->Infeed(shape, config); } @@ -3084,78 +3051,96 @@ XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, operand_shapes_with_layout); } -XlaOp Complex(const XlaOp& real, const XlaOp& imag, +XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return real.builder()->Complex(real, imag, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs, + broadcast_dimensions); } -XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } +XlaOp Conj(const XlaOp& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs, + broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs, + broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs, + broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs, + broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs, + broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs, + broadcast_dimensions); } XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs, + broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->And(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs, + broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs, + broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs, + broadcast_dimensions); } -XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } +XlaOp Not(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kNot, operand); +} XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, + broadcast_dimensions); } XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); } XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, @@ -3250,48 +3235,67 @@ XlaOp SelectAndScatterWithGeneralPadding( init_value, scatter); } -XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } +XlaOp Abs(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kAbs, operand); +} -XlaOp Atan2(const XlaOp& y, const XlaOp& x, +XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return y.builder()->Atan2(y, x, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs, + broadcast_dimensions); } -XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); } - -XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); } - -XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); } - -XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); } - -XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); } - -XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); } - -XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); } - -XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); } - -XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); } - -XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); } - -XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); } - -XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); } - -XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } - -XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } +XlaOp Exp(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kExp, operand); +} +XlaOp Expm1(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand); +} +XlaOp Floor(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kFloor, operand); +} +XlaOp Ceil(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kCeil, operand); +} +XlaOp Round(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} +XlaOp Log(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kLog, operand); +} +XlaOp Log1p(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand); +} +XlaOp Sign(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSign, operand); +} +XlaOp Clz(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kClz, operand); +} +XlaOp Cos(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kCos, operand); +} +XlaOp Sin(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kSin, operand); +} +XlaOp Tanh(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kTanh, operand); +} +XlaOp Real(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kReal, operand); +} +XlaOp Imag(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kImag, operand); +} XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions) { - return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); + return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs, + broadcast_dimensions); } XlaOp IsFinite(const XlaOp& operand) { - return operand.builder()->IsFinite(operand); + return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand); } XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { @@ -3302,7 +3306,9 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } +XlaOp Neg(const XlaOp& operand) { + return operand.builder()->UnaryOp(HloOpcode::kNegate, operand); +} XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); @@ -3316,6 +3322,12 @@ XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { return keys.builder()->Sort(keys, values, dimension); } +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension, bool is_stable) { + return operands[0].builder()->Sort(operands, comparator, dimension, + is_stable); +} + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index c429035ad0f96928525219a5506df81d64ffef95..b5bdb75998a4a0bfce3607715963c1609bc45ef8 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -56,6 +56,9 @@ class XlaOp { } ~XlaOp() = default; + XlaOp(const XlaOp& other) = default; + XlaOp& operator=(const XlaOp& other) = default; + // Precondition: !IsUninitialized(). // // It's very common to do foo.builder()->bar(). Without this precondition, if @@ -235,6 +238,10 @@ class XlaBuilder { // See also set_die_immediately_on_error(). Status first_error() const { return first_error_; } + // Returns the current status of the builder, complete with the stack trace + // information. + Status GetCurrentStatus() const; + // Returns the shape of the given op. StatusOr GetShape(const XlaOp& op) const; @@ -315,38 +322,6 @@ class XlaBuilder { XlaOp ConstantLiteral(const LiteralSlice& literal); - template - XlaOp ConstantR0(NativeT value); - template - XlaOp ConstantR1(absl::Span values); - XlaOp ConstantR1(const tensorflow::core::Bitmap& values); - template - XlaOp ConstantR2( - std::initializer_list> values); - template - XlaOp ConstantFromArrayWithLayout(const Array& values, - const Layout& layout); - template - XlaOp ConstantFromArray(const Array& values); - template - XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout); - template - XlaOp ConstantR2FromArray2D(const Array2D& values); - template - XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout); - template - XlaOp ConstantR3FromArray3D(const Array3D& values); - template - XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout); - template - XlaOp ConstantR4FromArray4D(const Array4D& values); - - template - XlaOp ConstantR1(int64 length, NativeT value); - XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); @@ -394,24 +369,6 @@ class XlaBuilder { XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); @@ -476,50 +433,6 @@ class XlaBuilder { const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); - XlaOp Complex(const XlaOp& real, const XlaOp& imag, - absl::Span broadcast_dimensions = {}); - - XlaOp Conj(const XlaOp& operand); - - XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp Not(const XlaOp& operand); - - XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); @@ -578,44 +491,6 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - XlaOp Abs(const XlaOp& operand); - - XlaOp Atan2(const XlaOp& y, const XlaOp& x, - absl::Span broadcast_dimensions = {}); - - XlaOp Exp(const XlaOp& operand); - - XlaOp Expm1(const XlaOp& operand); - - XlaOp Floor(const XlaOp& operand); - - XlaOp Ceil(const XlaOp& operand); - - XlaOp Round(const XlaOp& operand); - - XlaOp Log(const XlaOp& operand); - - XlaOp Log1p(const XlaOp& operand); - - XlaOp Sign(const XlaOp& operand); - - XlaOp Clz(const XlaOp& operand); - - XlaOp Cos(const XlaOp& operand); - - XlaOp Sin(const XlaOp& operand); - - XlaOp Tanh(const XlaOp& operand); - - XlaOp Real(const XlaOp& operand); - - XlaOp Imag(const XlaOp& operand); - - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - absl::Span broadcast_dimensions = {}); - - XlaOp IsFinite(const XlaOp& operand); - XlaOp Iota(const Shape& shape, int64 iota_dimension); XlaOp Iota(PrimitiveType type, int64 size); @@ -626,14 +501,15 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - XlaOp Neg(const XlaOp& operand); - XlaOp Transpose(const XlaOp& operand, absl::Span permutation); XlaOp Rev(const XlaOp& operand, absl::Span dimensions); + ABSL_DEPRECATED("Use form with comparator computation instead") XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); + XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension = -1, bool is_stable = false); XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -825,48 +701,6 @@ class XlaBuilder { const Shape& shape, const string& name); friend XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); - template - friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); - template - friend XlaOp ConstantR1(XlaBuilder* builder, - absl::Span values); - friend XlaOp ConstantR1(XlaBuilder* builder, - const tensorflow::core::Bitmap& values); - template - friend XlaOp ConstantR2( - XlaBuilder* builder, - std::initializer_list> values); - template - friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout); - template - friend XlaOp ConstantFromArray(XlaBuilder* builder, - const Array& values); - template - friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout); - template - friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values); - template - friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout); - template - friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values); - template - friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout); - template - friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values); - - template - friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); @@ -964,6 +798,9 @@ class XlaBuilder { const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); + friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, @@ -1089,6 +926,9 @@ class XlaBuilder { friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); friend XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension); + friend XlaOp Sort(absl::Span operands, + const XlaComputation& comparator, int64 dimension, + bool is_stable); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1484,6 +1324,32 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves for x in one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// +// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If `lower` is true (false), then the strictly upper +// (lower) triangular part of each innermost matrix in `a` is assumed to be +// zero and is not accessed. +// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a +// tensor of shape `[..., K, M]`. +// * `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// * `lower` is a boolean, indicating whether the argument `a` is +// lower-triangular +// (true) or upper-triangular (false). +// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be +// 1 and not accessed. +// * `transpose_a` indicates which function `op` we use to transform the tensor +// `a`: the identity function, transpose(a), or conjugate(transpose(a)) +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1819,7 +1685,7 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // of keys, in ascending order. // * If the keys have higher rank, the keys are sorted along the provided // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension -// value of 0 will indepenently sort every column, and a dimension value of 1 +// value of 0 will independently sort every column, and a dimension value of 1 // will independently sort each row. If no dimension number is provided, then // the last dimension is chosen by default. // @@ -1829,9 +1695,39 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and tensors with their // corresponding values as the other elements. +ABSL_DEPRECATED("Use form with comparator computation instead") XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); +// Enqueues a sort instruction onto the computation, using 'comparator' for +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. +// If only one operand is provided: +// * If the operand is a rank-1 tensor (an array), the result is a sorted array. +// The resulting sorting order has the property that for all index positions +// i, j with i < j, either +// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or +// comparator(value[i], value[j]) = true. +// * If the operand has higher rank, the operand is sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix), a dimension value +// of 0 will independently sort every column, and a dimension value of 1 will +// independently sort each row. If no dimension number is provided, then the +// last dimension is chosen by default. For the dimension which is sorted, the +// same sorting order applies as in the rank-1 case. +// +// If more than one operand is provided: +// * All operands must be tensors with the same dimensions. The element types of +// the tensors may be different. +// * The result is a tuple that consists of the operands in sorted order (along +// the provided dimension, as above). The same permutation as implied by the +// comparison computation is applied to all operand tensors. When comparing +// two index positions, 'comparator' is called with 2 * n scalar parameters, +// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at +// two index positions. +// Default comparator computations can be found in lib/comparators.h +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64 dimension = -1, bool is_stable = false); + // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -1970,81 +1866,6 @@ XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); // Implementation details below this point. // -template -XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(LiteralUtil::CreateR0(value)); -} - -template -XlaOp XlaBuilder::ConstantR1(absl::Span values) { - return ConstantLiteral(LiteralUtil::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(LiteralUtil::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(LiteralUtil::CreateR2(values)); -} - -template -XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, - const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(LiteralUtil::CreateFromArray(values)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { - return ConstantFromArray(values); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { - return ConstantFromArray(values); -} - // Free function template implementations. template diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index a9a91648ac377987e7f226116e11c9c697ace103..43d9ee0d9a5e689676b00e59d7c59bb0f4e37461 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -128,11 +128,6 @@ static void AllocateFlags() { tensorflow::Flag( "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag( - "xla_hlo_dump_as_graphdef", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), - flag_values->xla_hlo_dump_as_graphdef(), - "Dump HLO graphs as TensorFlow GraphDefs."), tensorflow::Flag("xla_hlo_dump_as_html", bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_html), flag_values->xla_hlo_dump_as_html(), @@ -144,13 +139,6 @@ static void AllocateFlags() { flag_values->xla_hlo_graph_sharding_color(), "Assign colors based on sharding assignments when generating the " "HLO graphs."), - tensorflow::Flag( - "xla_hlo_tfgraph_device_scopes", - bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), - flag_values->xla_hlo_tfgraph_device_scopes(), - "When generating TensorFlow HLO graphs, if the HLO instructions " - "are assigned to a specific device, prefix the name scope with " - "\"devX\" with X being the device ordinal."), tensorflow::Flag( "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), "HLO modules matching this regex will be dumped to LOG(INFO)."), diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 0f9b591c70d4fd96147958d18bd5fb7dd78a7f3f..230f3b202a4b531c381665471c3856c3feba5a3a 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,7 +77,7 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { } ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( - DeviceAssignment* device_assignment) { + const DeviceAssignment* device_assignment) { device_assignment_ = device_assignment; return *this; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 6f36d11dfb34eb27e79ea4ff797d35f80fb44b27..1e744953bd3be58afba5b81c0e2a8ba26665f9c4 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -74,7 +74,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); ExecutableRunOptions& set_device_assignment( - DeviceAssignment* device_assignment); + const DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; ExecutableRunOptions& set_rng_seed(int rng_seed); @@ -83,7 +83,7 @@ class ExecutableRunOptions { private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; - DeviceAssignment* device_assignment_ = nullptr; + const DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index 267701e9c0e42a21d2cda6238520f6a9692e7e76..d756cd74c98b98a6fda099690d966562bd694e2c 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -25,6 +25,8 @@ upper_tabs: path: /xla/operation_semantics - title: Shapes and layout path: /xla/shapes + - title: Tiled layout + path: /xla/tiled_layout - title: Using AOT compilation path: /xla/tfcompile - heading: Tutorials diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 363fd17b69bfbe54d486e367d9bf5cc0eee4205e..db90d184b5218614ac49363ebf2a7e25fffe44de 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1186,7 +1186,7 @@ if and only if the corresponding input element is finite. `Sign(operand)` Element-wise sign operation `x -> sgn(x)` where -$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$ +$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$ using the comparison operator of the element type of `operand`. diff --git a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md b/tensorflow/compiler/xla/g3doc/tiled_layout.md similarity index 96% rename from tensorflow/compiler/xla/g3doc/layout_with_tiling.md rename to tensorflow/compiler/xla/g3doc/tiled_layout.md index 5e990851af7495ebd4417e44f1d955fcc14dadf1..21e88ceab6208cdf940826d769fd93713044d5a0 100644 --- a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md +++ b/tensorflow/compiler/xla/g3doc/tiled_layout.md @@ -1,9 +1,7 @@ # Tiled layout -*Note: This doc describes how tiled layout is intended to work. Tiling is being -implemented, but this is an early effort and it is currently not even guaranteed -to get an Unimplemented error if one tries to use tiling - it may be just -silently ignored.* +Caution: Tiled layout is *pre-release* and this describes how it's intended to +work. Errors may be silently ignored.
![](images/xla_array_layout_figure1.png) diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 7e22a32e545e4155545ffcfb9582187eadec3a82..eebd8245abe759b71b3fe732943761325ea04b81 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index e3b5fcd5274881cec31ecf906e3461685f82a1f4..000c4fdc40519214fa9fa721a8987b77b534442b 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -30,7 +30,19 @@ TileProto Tile::ToProto() const { } string Tile::ToString() const { - return absl::StrCat("(", absl::StrJoin(dimensions(), ","), ")"); + std::vector elements; + for (auto dim : dimensions()) { + if (dim >= 0) { + elements.push_back(std::to_string(dim)); + } else { + if (dim == kCombineDimension) { + elements.push_back("*"); + } else { + elements.push_back(absl::StrCat("Invalid value ", dim)); + } + } + } + return absl::StrCat("(", absl::StrJoin(elements, ","), ")"); } /* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) { @@ -64,23 +76,43 @@ LayoutProto Layout::ToProto() const { } string Layout::ToString() const { - // TODO(b/119839262): Emit tiles in string. if (format() == SPARSE) { + CHECK_EQ(tiles_size(), 0) << "Sparse layout should not be tiled."; return absl::StrCat("sparse{", max_sparse_elements(), "}"); } else if (format() == DENSE) { - return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), "}"); + string colon_string = tiles().empty() ? "" : "T"; + for (Tile tile : tiles()) { + absl::StrAppend(&colon_string, tile.ToString()); + } + if (element_size_in_bits() != 0) { + absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")"); + } + return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), + colon_string.empty() ? "" : ":", colon_string, "}"); } else { CHECK_EQ(format(), INVALID_FORMAT); return "invalid{}"; } } +bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { + if (lhs.format() != rhs.format() || + lhs.minor_to_major() != rhs.minor_to_major() || + lhs.max_sparse_elements() != rhs.max_sparse_elements()) { + return false; + } + if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) { + return false; + } + if (!ignore_element_size_ && + lhs.element_size_in_bits() != rhs.element_size_in_bits()) { + return false; + } + return true; +} + bool Layout::operator==(const Layout& other) const { - return (other.format() == format() && - other.minor_to_major() == minor_to_major() && - other.element_size_in_bits() == element_size_in_bits() && - other.max_sparse_elements() == max_sparse_elements() && - other.tiles() == tiles()); + return Equal()(*this, other); } std::ostream& operator<<(std::ostream& out, const Tile& tile) { diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index 313368c39e4c976fc481941eb17325101f2ba69a..acc449b781b503142b24ed7229e3559230bb1599 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -55,6 +55,20 @@ class Tile { // Returns the dimensions of the tile. const std::vector& dimensions() const { return dimensions_; } + Tile& add_dimensions(int64 value) { + dimensions_.push_back(value); + return *this; + } + + Tile& clear_dimensions() { + dimensions_.clear(); + return *this; + } + + // This dimension size means the corresponding dimension in the shape is + // combined with the next minor dimension before tiling is applied. + static constexpr int64 kCombineDimension = std::numeric_limits::min(); + private: // The bounds of the tile. std::vector dimensions_; @@ -71,10 +85,12 @@ class Layout { // Constructs a dense tiled layout with the given minor-to-major order and // tiles. - Layout(absl::Span minor_to_major, absl::Span tiles) + Layout(absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits = 0) : format_(DENSE), minor_to_major_(minor_to_major.begin(), minor_to_major.end()), - tiles_(tiles.begin(), tiles.end()) {} + tiles_(tiles.begin(), tiles.end()), + element_size_in_bits_(element_size_in_bits) {} // Construct a shape from a LayoutProto. static Layout CreateFromProto(const LayoutProto& proto); @@ -85,6 +101,37 @@ class Layout { // Returns a human-readable string that represents this layout. string ToString() const; + // Equal is a configurable functor to check the equality of two layouts. + // + // Examples: + // + // - Comparing two layouts ignoring their difference in tiles: + // Equal().IgnoreTiles()(layout1, layout2); + // + // - Comparing two layouts ignoring their difference in tiles and element + // size: + // Equal().IgnoreTiles().IgnoreElementSize()(layout1, layout2); + class Equal { + public: + Equal() = default; + + bool operator()(const Layout& lhs, const Layout& rhs); + + Equal& IgnoreTiles() { + ignore_tiles_ = true; + return *this; + } + + Equal& IgnoreElementSize() { + ignore_element_size_ = true; + return *this; + } + + private: + bool ignore_tiles_ = false; + bool ignore_element_size_ = false; + }; + bool operator==(const Layout& other) const; bool operator!=(const Layout& other) const { return !(*this == other); } @@ -159,7 +206,7 @@ class Layout { element_size_in_bits_ = 0; } - public: + private: // The format of this layout. Format format_ = INVALID_FORMAT; @@ -172,11 +219,11 @@ class Layout { // memory. This field must be zero unless the format is SPARSE. int64 max_sparse_elements_ = 0; - // The number of bits used to store an individual array element. - int64 element_size_in_bits_ = 0; - // The tiles used in tiling-based layout. std::vector tiles_; + + // The number of bits used to store an individual array element. + int64 element_size_in_bits_ = 0; }; std::ostream& operator<<(std::ostream& out, const Tile& Tile); diff --git a/tensorflow/compiler/xla/layout_test.cc b/tensorflow/compiler/xla/layout_test.cc index fb6abd3f6523b978e72b21ec082ae06973e86243..f5d71c553ed2e0cfd5d5945144dd476557582b5f 100644 --- a/tensorflow/compiler/xla/layout_test.cc +++ b/tensorflow/compiler/xla/layout_test.cc @@ -38,10 +38,13 @@ TEST_F(LayoutTest, ToString) { "sparse{123}"); EXPECT_EQ(Layout({4, 5, 6}).ToString(), "{4,5,6}"); EXPECT_EQ(Layout({3, 2, 1, 0}, {Tile({42, 123}), Tile({4, 5})}).ToString(), - "{3,2,1,0}"); + "{3,2,1,0:T(42,123)(4,5)}"); EXPECT_EQ( Layout({1, 0}, {Tile({2, 55})}).set_element_size_in_bits(42).ToString(), - "{1,0}"); + "{1,0:T(2,55)E(42)}"); + EXPECT_EQ( + Layout({1, 0}, {Tile({-2, 55})}).set_element_size_in_bits(42).ToString(), + "{1,0:T(Invalid value -2,55)E(42)}"); } TEST_F(LayoutTest, StreamOut) { @@ -84,6 +87,15 @@ TEST_F(LayoutTest, Equality) { Layout().set_format(SPARSE).set_max_sparse_elements(42)); EXPECT_NE(Layout().set_format(SPARSE).set_max_sparse_elements(42), Layout().set_format(SPARSE).set_max_sparse_elements(24)); + + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {Tile({42, 44})}), Layout({0, 1, 2}))); + EXPECT_TRUE(Layout::Equal().IgnoreTiles()(Layout({0, 1, 2}, {Tile({42, 44})}), + Layout({0, 1, 2}))); + EXPECT_FALSE( + Layout::Equal()(Layout({0, 1, 2}, {}, 32), Layout({0, 1, 2}, {}, 1))); + EXPECT_TRUE(Layout::Equal().IgnoreElementSize()(Layout({0, 1, 2}, {}, 32), + Layout({0, 1, 2}, {}, 1))); } TEST_F(LayoutTest, LayoutToFromProto) { diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 2fe9b56c6bdffb931726f60ab75081361b43ebb4..62314118ca9713a04cb4e3cf6ad261b966d85f15 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -54,12 +54,24 @@ void SetDefaultLayoutToContainer(std::vector* minor_to_major) { } // namespace /* static */ Layout LayoutUtil::MakeLayout( - absl::Span minor_to_major) { + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); } + for (Tile tile : tiles) { + for (int64 dim : tile.dimensions()) { + if (dim < 0 && dim != Tile::kCombineDimension) { + LOG(FATAL) << "Tile dimension size needs to be mininum int64 value if " + "it's negative. Value is " + << dim; + } + } + *layout.add_tiles() = tile; + } + layout.set_element_size_in_bits(element_size_in_bits); return layout; } @@ -235,6 +247,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } dimensions_in_layout[dim] = true; } + } else { + if (layout.tiles_size() != 0) { + return InvalidArgument("Only dense layouts can be tiled."); + } } return Status::OK(); diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 609dba67bcdbcb11be0906b7d87a52a17ba0dfbd..9997aef465daa48ee77050e03d97cde0ea2425cc 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,7 +36,9 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) - static Layout MakeLayout(absl::Span minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major, + absl::Span tiles = {}, + int64 element_size_in_bits = 0); // Similar to MakeLayout, but take indices in reverse order. static Layout MakeLayoutFromMajorToMinor( diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4cc94c270cd64eb19761cc1044861c7d185b7888..12da214063676717aa075e66aa54974f4cc2b31b 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -317,6 +317,81 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, HumanStringWithTiling) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3, 4}, {0, 1, 2}); + Tile* tile; + + // No tiling. + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), "f32[2,3,4]{0,1,2}"); + + // 2D tile. + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(512); + tile->add_dimensions(1024); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "f32[2,3,4]{0,1,2:T(512,1024)}"); + + // 1D tile. + shape.mutable_layout()->clear_tiles(); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(512); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "f32[2,3,4]{0,1,2:T(512)}"); + + // 2 tiles. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 4}, {1, 2, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(16); + tile->add_dimensions(256); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(1); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[2,3,4]{1,2,0:T(16,256)(2,1)}"); + + // PRED with element size of 8 bits. + shape = ShapeUtil::MakeShapeWithLayout(PRED, {8, 8, 8}, {0, 2, 1}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(8); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:T(8,128)}"); + + // PRED with element size of 32 bits. + shape.mutable_layout()->clear_tiles(); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(8); + tile->add_dimensions(128); + shape.mutable_layout()->set_element_size_in_bits(32); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:T(8,128)E(32)}"); + + // No tile. PRED with element size of 32 bits. + shape.mutable_layout()->clear_tiles(); + shape.mutable_layout()->set_element_size_in_bits(32); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "pred[8,8,8]{0,2,1:E(32)}"); + + // Tile with negative dimension size for combining dimensions. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {2, 3, 1004}, {2, 1, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[2,3,1004]{2,1,0:T(2,*,128)}"); + + // Tile with two negative dimensions. + shape = ShapeUtil::MakeShapeWithLayout(BF16, {8, 2, 3, 1004}, {3, 2, 1, 0}); + tile = shape.mutable_layout()->add_tiles(); + tile->add_dimensions(2); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(Tile::kCombineDimension); + tile->add_dimensions(128); + EXPECT_EQ(ShapeUtil::HumanStringWithLayout(shape), + "bf16[8,2,3,1004]{3,2,1,0:T(2,*,*,128)}"); +} + TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); auto status = diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 8600e8752cfbe072407391559d210d0b49bea511..5cd738d0f7769ceac7eb3bdbc5abd3196d9cf99c 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -44,7 +44,6 @@ namespace xla { namespace { using absl::StrCat; -using absl::StrFormat; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; @@ -1628,26 +1627,20 @@ bool LiteralBase::IsAllFloat(float value) const { return true; } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue( - piece.data(), static_cast(value)); - default: - return false; - } - }; - if (!piece_is_all()) { - return false; + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; } - return true; }); } diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc index 91e71f5d1d02d135158d0dffc140c21cf8ea5e3a..e1e22f784172b5f3850f0bc510322dfad9e7f1bb 100644 --- a/tensorflow/compiler/xla/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -186,6 +186,14 @@ bool ParseFlagsFromEnvAndDieIfUnknown( tensorflow::mutex_lock lock(env_argv_mu); auto* env_argv = &EnvArgvs()[string(envvar)]; SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized + + if (VLOG_IS_ON(1)) { + VLOG(1) << "For env var " << envvar << " found arguments:"; + for (int i = 0; i < env_argv->argc; i++) { + VLOG(1) << " argv[" << i << "] = " << env_argv->argv[i]; + } + } + bool result = tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index f22fc8b8499dd4a5329276040331a2ed9e89bea9..4a88a48f2857a327aba3600ca72191e5c7b28585 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ +#include "google/protobuf/duration.pb.h" +#include "absl/time/time.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/protobuf.h" @@ -43,6 +45,20 @@ Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, // dirpath along as-is. void RegisterDirectoryExpander(const std::function& expander); +// Converts an absl::Duration to a google::protobuf::Duration. +inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + return proto; +} + +// Converts a google::protobuf::Duration to an absl::Duration. +inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 4afb21d5c8864c2974114af2de08df4106a13a8c..55eacc1c16a76522215d27ac7cf4e801e69c9740 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -59,10 +59,6 @@ cc_library( srcs = ["local_computation_builder.cc"], hdrs = ["local_computation_builder.h"], deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -77,15 +73,38 @@ cc_library( "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/client/lib:triangular_solve", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/core:lib", + "//third_party/python_runtime:headers", # buildcleaner: keep + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "xrt", + srcs = ["xrt.cc"], + hdrs = ["xrt.h"], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -93,9 +112,12 @@ cc_library( tf_py_wrap_cc( name = "pywrap_xla", - srcs = ["xla.i"], + srcs = [ + "xla.i", + ], swig_includes = [ "local_computation_builder.i", + "xla_data.i", "//tensorflow/python:platform/base.i", ], version_script = select({ @@ -112,3 +134,27 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla/service:cpu_plugin", ] + xla_python_default_plugins(), ) + +tf_py_wrap_cc( + name = "pywrap_xrt", + srcs = [ + "xrt.i", + ], + swig_includes = [ + "xla_data.i", + "//tensorflow/python:platform/base.i", + ], + version_script = select({ + "//tensorflow:darwin": "pywrap_xla_exported_symbols.lds", + "//tensorflow:windows": None, + "//conditions:default": "pywrap_xla_version_script.lds", + }), + visibility = ["//visibility:public"], + deps = [ + ":numpy_bridge", + ":xrt", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index ce4bd6f681b80a0c52579f62e3422be81d06076f..c14a01a858af414fc78a5f727372e8fa64cad4b8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,29 +20,22 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/cc/client/client_session.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" -#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" -#include "tensorflow/compiler/xrt/xrt.pb.h" -#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -51,72 +44,6 @@ limitations under the License. namespace xla { namespace swig { -// TODO(b/118641336): Factor out XRT parts into a small c++ library of their -// own. - -// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of -// device handles instead of needing to set the number of replicas at XLA -// service initialization time. -tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); -int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; -LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; - -string* GetPlatformNameString() { - static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = - new string("Host"); - return platform_name_string; -} - -Status InitializeReplicaCount(int replica_count) { - if (replica_count < 1) { - return InvalidArgument("Replica count must be >= 1; got %d.", - replica_count); - } - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the replica count to %d, but a local XLA service was " - "previously created with a replica count of %d.", - replica_count, g_replica_count); - } - g_replica_count = replica_count; - return Status::OK(); -} - -Status InitializePlatformName(const string& platform_name) { - string* g_platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return FailedPrecondition( - "Attempted to set the platform name to %s, but a local XLA service was " - "previously created with a platform name of %s.", - platform_name, *g_platform_name); - } - TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); - *g_platform_name = platform_name; - return Status::OK(); -} - -int GetReplicaCount() { - tensorflow::mutex_lock lock(g_local_client_mutex); - return g_replica_count; -} - -StatusOr GetOrCreateLocalClient() { - string* platform_name = GetPlatformNameString(); - tensorflow::mutex_lock lock(g_local_client_mutex); - if (g_local_client != nullptr) { - return g_local_client; - } - LocalClientOptions options; - options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); - options.set_number_of_replicas(g_replica_count); - TF_ASSIGN_OR_RETURN(g_local_client, - ClientLibrary::GetOrCreateLocalClient(options)); - CHECK(g_local_client != nullptr); - return g_local_client; -} - Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { const char* name = "xla._CPU_CUSTOM_CALL_TARGET"; if (!PyCapsule_IsValid(capsule, name)) { @@ -131,62 +58,66 @@ Status RegisterCpuCustomCallTarget(const string& fn_name, PyObject* capsule) { return Status::OK(); } -Status TransferToInfeedLocal(const Literal& literal) { - VLOG(1) << "Infeeding literal without replica number; shape: " - << literal.shape(); - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); -} +LocalClient::LocalClient(xla::LocalClient* client) : client_(client) {} -Status TransferToInfeedLocalReplica(const Literal& literal, - int replica_number) { - VLOG(1) << "Infeeding shape " << literal.shape() - << " to replica number: " << replica_number; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferToInfeedLocal(literal, device_ordinal); +/* static */ StatusOr LocalClient::Get( + const string& platform_name) { + TF_ASSIGN_OR_RETURN(se::Platform * platform, + PlatformUtil::GetPlatform(platform_name)); + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("Platform %s has no visible devices.", + platform_name); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + CHECK(client != nullptr); + return LocalClient(client); } -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number) { - VLOG(1) << "Outfeeding literal from replica number: " << replica_number - << " shape: " << shape; - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - return client->TransferFromOutfeedLocal(shape, device_ordinal); +// Returns the number of devices known to the XLA client. +int LocalClient::DeviceCount() const { return client_->device_count(); } + +Status LocalClient::TransferToInfeed(const Literal& literal, + int device_ordinal) { + VLOG(1) << "Infeeding literal to device " << device_ordinal + << "; shape: " << literal.shape(); + return client_->TransferToInfeed(literal, device_ordinal); } -static StatusOr ToBuffer(LocalClient* client, - int device_ordinal, - const Literal& arg) { - return client->LiteralToShapedBuffer(arg, device_ordinal, - client->backend().memory_allocator()); +StatusOr LocalClient::TransferFromOutfeed(const Shape& shape, + int device_ordinal) { + VLOG(1) << "Outfeeding literal from device " << device_ordinal + << "; shape: " << shape; + return client_->TransferFromOutfeed(&shape, device_ordinal); } /* static */ StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(int device_ordinal, - client->ReplicaNumberToDeviceOrdinal(replica_number)); - VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " - << replica_number << "/" << device_ordinal; + const LocalClient& client, int device_ordinal) { + VLOG(1) << "Creating shaped buffer from literal on device ordinal: " + << device_ordinal; + auto literal_to_buffer = [&](const Literal& arg) { + return client.client()->LiteralToShapedBuffer( + arg, device_ordinal, client.client()->backend().memory_allocator()); + }; + StatusOr buf = [&] { if (shape_with_layout) { Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, device_ordinal, relaid); + return literal_to_buffer(relaid); } - return ToBuffer(client, device_ordinal, argument); + return literal_to_buffer(argument); }(); TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie(), client.client()); } -LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) - : shaped_buffer_(std::move(shaped_buffer)) {} +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, + xla::LocalClient* client) + : shaped_buffer_(std::move(shaped_buffer)), client_(client) {} const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; @@ -199,8 +130,7 @@ const Shape& LocalShapedBuffer::shape() const { } StatusOr LocalShapedBuffer::ToLiteral() const { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - return client->ShapedBufferToLiteral(*shaped_buffer()); + return client_->ShapedBufferToLiteral(*shaped_buffer()); } LocalShapedBufferTuple::LocalShapedBufferTuple( @@ -231,120 +161,77 @@ StatusOr LocalShapedBufferTuple::Release(int i) { int64 LocalShapedBufferTuple::size() const { return elements_.size(); } -XrtAllocation::XrtAllocation(int64 handle, Shape shape, - const string& session_target) - : handle_(handle), shape_(shape), session_target_(session_target) {} - -XrtAllocation::~XrtAllocation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } +StatusOr LocalShapedBuffer::DestructureTuple() { + const Shape tuple_shape = shape(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); } -} - -/* static */ -StatusOr XrtAllocation::FromLiteral( - const Literal& argument, const string& session_target) { - xrt::XLAAllocation alloc; - *alloc.mutable_value() = argument.ToProto(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto literal_string = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); - TF_RETURN_IF_ERROR(root.status()); - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({literal_string, alloc.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + DeviceMemoryAllocator* allocator = shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = Release(); - int64 handle = outputs[0].scalar()(); - return new XrtAllocation(handle, argument.shape(), session_target); -} - -const int64 XrtAllocation::handle() const { return handle_; } - -const Shape& XrtAllocation::shape() const { return shape_; } - -StatusOr XrtAllocation::ToLiteral() const { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto allocation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); - TF_RETURN_IF_ERROR(root.status()); + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({allocation_handle, handle()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + ShapeTree& shape_tree = tuple_buffer.buffers(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - xla::LiteralProto response; - TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); - return Literal::CreateFromProto(response); -} + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); -XrtAllocationTuple::XrtAllocationTuple(std::vector elements) - : elements_(std::move(elements)) { - for (auto* element : elements_) { - CHECK(element != nullptr); + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator), client_)); } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); } -XrtAllocationTuple::~XrtAllocationTuple() { - for (XrtAllocation* element : elements_) { - if (element != nullptr) { - delete element; - } - } -} +LocalExecutable::LocalExecutable( + std::unique_ptr executable, + xla::DeviceAssignment device_assignment, xla::LocalClient* client) + : executable_(std::move(executable)), + device_assignment_(std::move(device_assignment)), + client_(client) {} -StatusOr XrtAllocationTuple::Release(int i) { - XrtAllocation* element = elements_[i]; - if (element == nullptr) { - return InvalidArgument("Attempted to release already-released element %d.", - i); +std::vector LocalExecutable::DeviceOrdinals() const { + int num_replicas = device_assignment_.replica_count(); + std::vector device_ordinals; + device_ordinals.reserve(num_replicas); + for (int i = 0; i < num_replicas; ++i) { + device_ordinals.push_back(device_assignment_(i, 0)); } - elements_[i] = nullptr; - return element; + return device_ordinals; } -int64 XrtAllocationTuple::size() const { return elements_.size(); } - -CompiledLocalComputation::CompiledLocalComputation( - std::unique_ptr executable) - : executable_(std::move(executable)) {} - -StatusOr CompiledLocalComputation::Execute( +StatusOr LocalExecutable::Execute( absl::Span argument_handles) { if (num_replicas() != 1) { return InvalidArgument( "Attempted to execute computation with %d replicas using Execute()", num_replicas()); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - client->backend().computation_placer()->AssignDevices( - 1, /*computation_count=*/1)); StatusOr result_buffer_status; - const int device_ordinal = device_assignment(0, 0); + const int device_ordinal = device_assignment_(0, 0); VLOG(3) << "Replica 0 mapped to device ordinal for execution: " << device_ordinal; @@ -356,10 +243,10 @@ StatusOr CompiledLocalComputation::Execute( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); result_buffer_status = executable_->Run(argument_buffers, options); @@ -369,13 +256,13 @@ StatusOr CompiledLocalComputation::Execute( "%s.", result_buffer_status.status().ToString()); } - return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(), + client_); } -StatusOr CompiledLocalComputation::ExecutePerReplica( +StatusOr LocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); - const int num_devices = client->device_count(); + const int num_devices = client_->device_count(); if (argument_handles.size() != num_replicas()) { return InvalidArgument( @@ -390,14 +277,9 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( VLOG(1) << "Executing with " << num_replicas() << " replicas."; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - client->backend().computation_placer()->AssignDevices( - num_replicas(), /*computation_count=*/1)); - std::vector> results(num_replicas()); - auto execute = [this, client, &device_assignment, &argument_handles, - &results](int replica) { - const int device_ordinal = device_assignment(replica, 0); + auto execute = [this, &argument_handles, &results](int replica) { + const int device_ordinal = device_assignment_(replica, 0); VLOG(3) << "Replica " << replica << " mapped to device ordinal for execution: " << device_ordinal; @@ -409,10 +291,10 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); + options.set_allocator(client_->backend().memory_allocator()); options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); + client_->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment_); StatusOr result_buffer_status = executable_->Run(argument_buffers, options); @@ -444,151 +326,43 @@ StatusOr CompiledLocalComputation::ExecutePerReplica( replica, statusor.status().ToString()); } wrapped_results[replica] = - new LocalShapedBuffer(std::move(statusor).ValueOrDie()); + new LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_); } return new LocalShapedBufferTuple(std::move(wrapped_results)); } -static StatusOr GetReturnValueShape(const XlaComputation& computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation.GetProgramShape()); - return std::move(*program_shape.mutable_result()); -} - -CompiledXrtComputation::CompiledXrtComputation( - const ProgramShape& program_shape, int64 handle, - const string& session_target) - : program_shape_(program_shape), - handle_(handle), - session_target_(session_target) {} - -CompiledXrtComputation::~CompiledXrtComputation() { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto release = - tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); - if (!root.status().ok()) { - LOG(ERROR) << root.status(); - return; - } - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({computation_handle, handle()}); - std::vector outputs; - auto status = session.Run(inputs, {}, {release}, &outputs); - if (!status.ok()) { - LOG(ERROR) << status; - return; - } -} - -StatusOr CompiledXrtComputation::Execute( - absl::Span argument_handles) { - const int num_expected_arguments = program_shape().parameters().size(); - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - std::vector arguments; - arguments.reserve(num_expected_arguments); - for (int i = 0; i < num_expected_arguments; ++i) { - arguments.push_back( - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); - } - auto computation_handle = - tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto execution_config = - tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto execute = tensorflow::ops::XRTExecute(root, computation_handle, - execution_config, arguments); - TF_RETURN_IF_ERROR(root.status()); - - TF_RET_CHECK(argument_handles.size() == arguments.size()); - - xrt::XRTExecutionConfig e; - e.set_release_input_handles(false); - e.set_release_compilation_handle(false); - - tensorflow::ClientSession session(root, session_target_); - tensorflow::ClientSession::FeedType inputs; - for (int i = 0; i < arguments.size(); ++i) { - inputs.insert({arguments[i], argument_handles[i]->handle()}); - } - inputs.insert({computation_handle, handle()}); - inputs.insert({execution_config, e.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); - - int64 output = outputs[0].scalar()(); - return new XrtAllocation(output, program_shape().result(), session_target_); -} - -const ProgramShape& CompiledXrtComputation::program_shape() const { - return program_shape_; -} - -int64 CompiledXrtComputation::handle() const { return handle_; } - -LocalComputation::LocalComputation(XlaComputation computation) +Computation::Computation(XlaComputation computation) : computation_(std::move(computation)) {} -StatusOr LocalComputation::Compile( +StatusOr Computation::Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options) { + const ExecutableBuildOptions* build_options, const LocalClient& client) { std::vector argument_shape_pointers; argument_shape_pointers.reserve(argument_shapes.size()); for (auto& argument_shape : argument_shapes) { argument_shape_pointers.push_back(&argument_shape); } - TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient()); ExecutableBuildOptions options; if (build_options != nullptr) { options = *build_options; } TF_ASSIGN_OR_RETURN( auto local_executable, - client->Compile(computation_, argument_shape_pointers, options)); - return new CompiledLocalComputation(std::move(local_executable)); -} - -StatusOr LocalComputation::CompileForXrt( - const std::vector& argument_shapes, const string& session_target) { - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); - auto compile = tensorflow::ops::XRTCompile(root, program); - TF_RETURN_IF_ERROR(root.status()); - - xrt::XLAComputation c; - auto config = c.mutable_config(); - ProgramShape shapes; - for (auto& shape : argument_shapes) { - *shapes.add_parameters() = shape; - } - TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(&shapes); - *config->mutable_program_shape() = shapes.ToProto(); - auto snapshot = computation().Snapshot().ValueOrDie(); - *c.mutable_hlo_snapshot() = *snapshot; - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - inputs.insert({program, c.SerializeAsString()}); - std::vector outputs; - TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + client.client()->Compile(computation_, argument_shape_pointers, options)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + client.client()->backend().computation_placer()->AssignDevices( + options.num_replicas(), /*computation_count=*/1)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation().GetProgramShape()); - int64 handle = outputs[0].scalar()(); - return new CompiledXrtComputation(program_shape, handle, session_target); + return new LocalExecutable(std::move(local_executable), + std::move(device_assignment), client.client()); } -const XlaComputation& LocalComputation::computation() const { - return computation_; -} +const XlaComputation& Computation::computation() const { return computation_; } -string LocalComputation::GetSerializedProto() const { +string Computation::GetSerializedProto() const { string result; if (!computation_.proto().SerializeToString(&result)) { LOG(ERROR) << "Failed to serialize the HloModuleProto."; @@ -597,101 +371,129 @@ string LocalComputation::GetSerializedProto() const { return result; } -StatusOr LocalComputation::GetReturnValueShape() const { - return swig::GetReturnValueShape(computation_); +StatusOr Computation::GetHloText() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + HloPrintOptions options; + options = HloPrintOptions::ShortParsable(); + options.set_print_large_constants(false); + return hlo_module->ToString(options); +} + +StatusOr Computation::GetHloDotGraph() const { + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + computation_.proto(), GetDebugOptionsFromFlags())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(computation_.proto(), module_config)); + hlo_graph_dumper::DotGraphOptions options; + options.debug_options = &hlo_module->config().debug_options(); + return hlo_graph_dumper::HloComputationToDotGraph( + *hlo_module->entry_computation(), options); +} + +StatusOr Computation::GetProgramShape() const { + return computation_.GetProgramShape(); +} + +StatusOr Computation::GetReturnValueShape() const { + TF_ASSIGN_OR_RETURN(ProgramShape shape, computation_.GetProgramShape()); + return std::move(*shape.mutable_result()); } LocalOp::LocalOp(const XlaOp& op) : op_(op) {} const XlaOp& LocalOp::op() const { return op_; } -LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) +ComputationBuilder::ComputationBuilder(const string& computation_name) : builder_(computation_name) {} -void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { +void ComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); } -void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } +void ComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } -StatusOr LocalComputationBuilder::Build() { +StatusOr ComputationBuilder::Build() { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp ComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, const string& name) { return xla::Parameter(&builder_, parameter_number, shape, name); } -StatusOr LocalComputationBuilder::BuildWithRoot( - const LocalOp& root) { +StatusOr ComputationBuilder::BuildWithRoot(const LocalOp& root) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { +StatusOr ComputationBuilder::GetShape(const LocalOp& operand) { return builder_.GetShape(operand.op()); } -StatusOr LocalComputationBuilder::GetReturnValueShape() { +StatusOr ComputationBuilder::GetReturnValueShape() { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); return program_shape.result(); } -LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp ComputationBuilder::Infeed(const Shape& shape) { return xla::Infeed(&builder_, shape); } -void LocalComputationBuilder::Outfeed(const LocalOp& operand, - const Shape& shape, - const string& outfeed_config) { +void ComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, + const string& outfeed_config) { xla::Outfeed(operand.op(), shape, outfeed_config); } -LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { +LocalOp ComputationBuilder::ConstantLiteral(const Literal& literal) { return xla::ConstantLiteral(&builder_, literal); } -LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) { +LocalOp ComputationBuilder::Iota(PrimitiveType element_type, int64 size) { return xla::Iota(&builder_, element_type, size); } -LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape, - int64 dimension) { +LocalOp ComputationBuilder::BroadcastedIota(const Shape& shape, + int64 dimension) { return xla::Iota(&builder_, shape, dimension); } -LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, absl::Span broadcast_sizes) { +LocalOp ComputationBuilder::Broadcast(const LocalOp& operand, + absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } -LocalOp LocalComputationBuilder::BroadcastInDim( +LocalOp ComputationBuilder::BroadcastInDim( const LocalOp& operand, absl::Span out_dim_sizes, absl::Span broadcast_dimensions) { return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); } -LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, - const LocalOp& padding_value, - const PaddingConfig& padding_config) { +LocalOp ComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, - absl::Span dimensions, - absl::Span new_sizes) { +LocalOp ComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::AllToAll( +LocalOp ComputationBuilder::AllToAll( const LocalOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, absl::Span replica_groups) { std::vector rg(replica_groups.size()); @@ -702,39 +504,38 @@ LocalOp LocalComputationBuilder::AllToAll( split_count, rg); } -LocalOp LocalComputationBuilder::CrossReplicaSum( +LocalOp ComputationBuilder::CrossReplicaSum( const LocalOp& operand, absl::Span replica_groups) { return xla::CrossReplicaSum(operand.op(), replica_groups); } -LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides) { +LocalOp ComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } -LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, - int64 start_index, - int64 limit_index, int64 stride, - int64 dimno) { +LocalOp ComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } -LocalOp LocalComputationBuilder::DynamicSlice( - const LocalOp& operand, const LocalOp& start_indices, - absl::Span slice_sizes) { +LocalOp ComputationBuilder::DynamicSlice(const LocalOp& operand, + const LocalOp& start_indices, + absl::Span slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -LocalOp LocalComputationBuilder::DynamicUpdateSlice( - const LocalOp& operand, const LocalOp& update, - const LocalOp& start_indices) { +LocalOp ComputationBuilder::DynamicUpdateSlice(const LocalOp& operand, + const LocalOp& update, + const LocalOp& start_indices) { return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, - int64 dimension) { +LocalOp ComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -743,18 +544,18 @@ LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, return xla::ConcatInDim(&builder_, xla_ops, dimension); } -LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, +LocalOp ComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span> padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter) { + const LocalOp& init_value, const Computation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { +LocalOp ComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -764,22 +565,22 @@ LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { return xla::Tuple(&builder_, xla_ops); } -LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, - int64 index) { +LocalOp ComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { return xla::GetTupleElement(tuple_data.op(), index); } -LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { +LocalOp ComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { return xla::Dot(lhs.op(), rhs.op()); } -LocalOp LocalComputationBuilder::DotGeneral( +LocalOp ComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -LocalOp LocalComputationBuilder::ConvGeneralDilated( +LocalOp ComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -791,18 +592,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( feature_group_count); } -LocalOp LocalComputationBuilder::ConvertElementType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::ConvertElementType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::BitcastConvertType( - const LocalOp& operand, PrimitiveType new_element_type) { +LocalOp ComputationBuilder::BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type) { return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, - absl::Span operands) { +LocalOp ComputationBuilder::Call(const Computation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -811,7 +612,7 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, return xla::Call(&builder_, local_computation.computation(), xla_ops); } -LocalOp LocalComputationBuilder::CustomCall( +LocalOp ComputationBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const std::vector& operand_shapes_with_layout, @@ -826,19 +627,19 @@ LocalOp LocalComputationBuilder::CustomCall( operand_shapes_with_layout, opaque); } -LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, absl::Span permutation) { +LocalOp ComputationBuilder::Transpose(const LocalOp& operand, + absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, - absl::Span dimensions) { +LocalOp ComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map(absl::Span operands, - const LocalComputation& local_computation, - absl::Span dimensions) { +LocalOp ComputationBuilder::Map(absl::Span operands, + const Computation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -849,17 +650,17 @@ LocalOp LocalComputationBuilder::Map(absl::Span operands, dimensions); } -LocalOp LocalComputationBuilder::Reduce( +LocalOp ComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } -LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( +LocalOp ComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -871,51 +672,50 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( padding); } -LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, - const LocalOp& sigma, - const Shape& shape) { +LocalOp ComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape) { return xla::RngNormal(mu.op(), sigma.op(), shape); } -LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, - const Shape& shape) { +LocalOp ComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { return xla::RngUniform(a.op(), b.op(), shape); } -LocalOp LocalComputationBuilder::While(const LocalComputation& condition, - const LocalComputation& body, - const LocalOp& init) { +LocalOp ComputationBuilder::While(const Computation& condition, + const Computation& body, + const LocalOp& init) { return xla::While(condition.computation(), body.computation(), init.op()); } -LocalOp LocalComputationBuilder::Conditional( - const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation) { +LocalOp ComputationBuilder::Conditional(const LocalOp& predicate, + const LocalOp& true_operand, + const Computation& true_computation, + const LocalOp& false_operand, + const Computation& false_computation) { return xla::Conditional(predicate.op(), true_operand.op(), true_computation.computation(), false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { +StatusOr ComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } -LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { +LocalOp ComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { return xla::Sort(operand.op(), {}, dimension); } -LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, - const LocalOp& values, - int64 dimension) { +LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, int64 dimension) { return xla::Sort(keys.op(), {values.op()}, dimension); } -LocalOp LocalComputationBuilder::Cholesky(const LocalOp& a) { +LocalOp ComputationBuilder::Cholesky(const LocalOp& a) { return xla::Cholesky(a.op()); } -LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { +LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { XlaBuilder* builder = a.op().builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto qr, xla::QRDecomposition(a.op(), full_matrices)); @@ -923,16 +723,16 @@ LocalOp LocalComputationBuilder::QR(const LocalOp& a, bool full_matrices) { }); } -LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a, - const LocalOp& b, - bool left_side, bool lower, - bool transpose_a, - bool conjugate_a) { - return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a, - conjugate_a); +LocalOp ComputationBuilder::TriangularSolve(const LocalOp& a, const LocalOp& b, + bool left_side, bool lower, + bool unit_diagonal, + int transpose_a) { + return xla::TriangularSolve( + a.op(), b.op(), left_side, lower, unit_diagonal, + xla::TriangularSolveOptions::Transpose(transpose_a)); } -LocalOp LocalComputationBuilder::Gather( +LocalOp ComputationBuilder::Gather( const LocalOp& input, const LocalOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes) { @@ -940,24 +740,24 @@ LocalOp LocalComputationBuilder::Gather( slice_sizes); } -LocalOp LocalComputationBuilder::Scatter( +LocalOp ComputationBuilder::Scatter( const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers) { return xla::Scatter(input.op(), scatter_indices.op(), updates.op(), update_computation.computation(), dimension_numbers); } -StatusOr LocalComputationBuilder::BuildConstantSubGraph( +StatusOr ComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.BuildConstantSubGraph(operand.op())); - return new LocalComputation(std::move(computation)); + return new Computation(std::move(computation)); } -#define _FORWARD(method_name, return_sig, args_sig, args) \ - return_sig LocalComputationBuilder::method_name args_sig { \ - return xla::method_name args; \ +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig ComputationBuilder::method_name args_sig { \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -999,6 +799,7 @@ _FORWARD_BINOP(Atan2) _FORWARD_BINOP(Pow) _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) +_FORWARD_UNOP(Clz) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Expm1) @@ -1044,108 +845,9 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; } -void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } - -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { - delete computation; -} - -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { - delete computation; -} - -void DeleteLocalComputation(LocalComputation* computation) { - delete computation; -} - -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer) { - const Shape tuple_shape = local_shaped_buffer->shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - DeviceMemoryAllocator* allocator = - local_shaped_buffer->shaped_buffer()->memory_allocator(); - ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); +void DeleteLocalExecutable(LocalExecutable* computation) { delete computation; } - // Extract some metadata we use to construct scoped buffers. - const se::Platform* platform = tuple_buffer.platform(); - int device_ordinal = tuple_buffer.device_ordinal(); - - ShapeTree& shape_tree = tuple_buffer.buffers(); - std::vector results; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - // Create a shaped buffer for this destructured tuple element. - const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); - VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; - ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); - - ShapeUtil::ForEachSubshape( - subshape, [&](const Shape& s, const ShapeIndex& index) { - ShapeIndex original(index); - original.push_front(i); - se::DeviceMemoryBase* device_memory = - shape_tree.mutable_element(original); - shaped_buffer.set_buffer(*device_memory, index); - *device_memory = se::DeviceMemoryBase(); - }); - - VLOG(3) << "Completed tuple element: " << i; - results.push_back(new LocalShapedBuffer( - ScopedShapedBuffer(std::move(shaped_buffer), allocator))); - } - // Deallocate the root buffer. - se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); - TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); - return new LocalShapedBufferTuple(std::move(results)); -} - -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target) { - const Shape& tuple_shape = allocation->shape(); - - if (!tuple_shape.IsTuple()) { - return InvalidArgument( - "Attemped to destructure a LocalShapedBuffer that did not have a tuple " - "shape; shape: %s", - ShapeUtil::HumanString(tuple_shape)); - } - - tensorflow::Scope root = tensorflow::Scope::NewRootScope(); - auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); - auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); - auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); - TF_RETURN_IF_ERROR(root.status()); - - tensorflow::ClientSession session(root, session_target); - tensorflow::ClientSession::FeedType inputs; - std::vector results; - for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { - inputs.clear(); - inputs.insert({base_handle, allocation->handle()}); - inputs.insert({shape_index, {i}}); - std::vector outputs; - auto status = session.Run(inputs, {subtuple}, &outputs); - if (!status.ok()) { - // Clean up before returning non-ok status. - for (int j = 0; j < results.size(); ++j) { - delete results[j]; - } - return status; - } - const int64 subtuple_handle = outputs[0].scalar()(); - const Shape& subtuple_shape = - ShapeUtil::GetTupleElementShape(tuple_shape, i); - results.push_back( - new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); - } - return new XrtAllocationTuple(std::move(results)); -} +void DeleteComputation(Computation* computation) { delete computation; } } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e3af88f82559c32a7267a56c87d3bafda01b934d..66b1cce7fb598388af40940ea2ed52ac2f8ee8e1 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -22,9 +22,6 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -35,42 +32,42 @@ limitations under the License. namespace xla { namespace swig { -// Initializes the number of replicas that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializeReplicaCount(int replica_count); - -// Initializes the platform name that XLA will be initialized with (when -// first obtaining a handle to the local XLA service). If this is called after -// the handle to the local XLA service has been established, then an error is -// returned. -Status InitializePlatformName(const string& platform_name); - -// Returns the replica count that is currently set, regardless of whether the -// local XLA service has been instantiated yet or not. -int GetReplicaCount(); - // Registers a 'fn_capsule' as a CPU custom call target. // 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name // "xla._CPU_CUSTOM_CALL_TARGET". Status RegisterCpuCustomCallTarget(const string& name, PyObject* fn_capsule); -// Wraps the local client's infeed-transfer function. -// -// The default device ordinal (0) is used. -Status TransferToInfeedLocal(const Literal& literal); +// Wrapper around an xla::LocalClient. +class LocalClient { + public: + // Initializes a local XLA client for `platform_name`. Returns an error if no + /// such platform exists, or if the platform has no visible devices. + static StatusOr Get(const string& platform_name); + + // Copyable and moveable; the class is just a wrapper around a + // xla::LocalClient pointer for convenient SWIG wrapping. + + // Returns the number of devices known to the XLA client. + int DeviceCount() const; + + // Wraps the local client's infeed-transfer function. + // + // The default device ordinal (0) is used. + Status TransferToInfeed(const Literal& literal, int device_ordinal); -// Transfers the given literal to the infeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); + // Transfers a literal of the given shape from the outfeed of the given + // replica. + StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); -// Transfers a literal of the given shape from the outfeed of the given replica. -// -// The replica number is resolved to an appropriate device ordinal. -StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, - int replica_number); + xla::LocalClient* client() const { return client_; } + + private: + LocalClient(xla::LocalClient* client); + + xla::LocalClient* client_; +}; + +class LocalShapedBufferTuple; // Represents a reference to literals that live in a device-allocated buffer via // XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a @@ -79,9 +76,9 @@ class LocalShapedBuffer { public: static StatusOr FromLiteral( const Literal& argument, const absl::optional& shape_with_layout, - int replica_number); + const LocalClient& client, int device_ordinal); - LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer, xla::LocalClient* client); StatusOr ToLiteral() const; const Shape& shape() const; const ScopedShapedBuffer* shaped_buffer() const; @@ -90,8 +87,13 @@ class LocalShapedBuffer { // analogous to std::unique_ptr::release(). ShapedBuffer Release(); + // Destructures a tuple-valued LocalShapedBuffer into its constitutent + // elements in LocalShapedBufferTuple form. + StatusOr DestructureTuple(); + private: ScopedShapedBuffer shaped_buffer_; + xla::LocalClient* client_; }; // Result of a tuple destructuring operation on a LocalShapedBuffer -- this @@ -117,73 +119,21 @@ class LocalShapedBufferTuple { std::vector elements_; }; -// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements -// in LocalShapedBufferTuple form. -StatusOr DestructureLocalShapedBufferTuple( - LocalShapedBuffer* local_shaped_buffer); - -// Represents a reference to literals that live in a device-allocated buffer via -// XRT. Specifically, wraps an int64 handle produced by running the allocation -// graph, and an XLA shape to track the referent's shape. -class XrtAllocation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which allocation and deallocation - // graphs are run. - static StatusOr FromLiteral(const Literal& argument, - const string& session_target); - - XrtAllocation(int64 handle, Shape shape, const string& session_target); - ~XrtAllocation(); - StatusOr ToLiteral() const; - const Shape& shape() const; - const int64 handle() const; - - private: - const int64 handle_; - const Shape shape_; - const string session_target_; -}; - -// Result of a tuple destructuring operation on an XrtAllocation. -class XrtAllocationTuple { - public: - // Note: any XrtAllocation elements that are not Release()'d will be - // deallocated in the destructor. - explicit XrtAllocationTuple(std::vector elements); - - ~XrtAllocationTuple(); - - // Releases the ith element to the caller. Further attempts to release the ith - // element will return an invalid argument error. - StatusOr Release(int i); - - // Returns the number of elements in the destructured tuple. - int64 size() const; - - private: - std::vector elements_; -}; - -// Destructures a tuple-valued XrtAllocation into its constitutent elements -// in XrtAllocationTuple form. -// -// Accepts a `session_target` argument, used in constructing the -// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, -// and passed along in constructing each constituent XrtAllocation. -StatusOr DestructureXrtAllocationTuple( - XrtAllocation* allocation, const string& session_target); - // Represents a compiled computation that can be executed given handles to // device-allocated literals. Specifically, wraps an XLA LocalExecutable. -class CompiledLocalComputation { +class LocalExecutable { public: - CompiledLocalComputation(std::unique_ptr executable); + LocalExecutable(std::unique_ptr executable, + xla::DeviceAssignment device_assignment, + xla::LocalClient* client); int num_replicas() const { return executable_->build_options().num_replicas(); } + // Returns the device ordinals to which each replica is assigned. + std::vector DeviceOrdinals() const; + StatusOr Execute( absl::Span argument_handles); @@ -194,47 +144,22 @@ class CompiledLocalComputation { absl::Span > argument_handles); private: - std::unique_ptr executable_; + const std::unique_ptr executable_; + const xla::DeviceAssignment device_assignment_; + xla::LocalClient* const client_; }; -// Represents a compiled computation that can be executed given handles to -// device-allocated literals. Specifically, wraps an XRT computation handle. -class CompiledXrtComputation { - public: - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the execution graph is run. - CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, - const string& session_target); - ~CompiledXrtComputation(); - - StatusOr Execute( - absl::Span argument_handles); - - const ProgramShape& program_shape() const; - int64 handle() const; - - private: - const ProgramShape program_shape_; - const int64 handle_; - const string session_target_; -}; - -// Wraps a XlaComputation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a ComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. -class LocalComputation { +class Computation { public: - LocalComputation(XlaComputation computation); + Computation(XlaComputation computation); - StatusOr Compile( + StatusOr Compile( const std::vector& argument_shapes, - const ExecutableBuildOptions* build_options); - - // Accepts a `session_target` argument, used in constructing the - // `tensorflow::ClientSession` instance in which the compilation graph is run. - StatusOr CompileForXrt( - const std::vector& argument_shapes, const string& session_target); + const ExecutableBuildOptions* build_options, const LocalClient& client); const XlaComputation& computation() const; @@ -243,6 +168,15 @@ class LocalComputation { // string on failure. string GetSerializedProto() const; + // Returns the computation in human-readable HLO text format. + StatusOr GetHloText() const; + + // Returns the computation in graphviz dot format. + StatusOr GetHloDotGraph() const; + + // Returns the program shape for this computation. + StatusOr GetProgramShape() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; @@ -250,7 +184,7 @@ class LocalComputation { XlaComputation computation_; }; -// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// Wraps a XlaOp produced by a ComputationBuilder. This class is intended // to be made available to Python via SWIG. class LocalOp { public: @@ -267,20 +201,20 @@ class LocalOp { // Python. // - Set up the underlying builder to use the client library's // LocalClient. -// - Wrap Computations in LocalComputations for Python access. -// - Correspondingly unwrap incoming LocalComputations. -class LocalComputationBuilder { +// - Wrap Computations in Computations for Python access. +// - Correspondingly unwrap incoming Computations. +class ComputationBuilder { public: - LocalComputationBuilder(const string& computation_name); + ComputationBuilder(const string& computation_name); void SetOpMetadata(const OpMetadata& metadata); void ClearOpMetadata(); - // Returns an owned LocalComputation to the caller on success. - StatusOr Build(); + // Returns an owned Computation to the caller on success. + StatusOr Build(); - // Returns an owned LocalComputation to the caller on success with given root. - StatusOr BuildWithRoot(const LocalOp& root); + // Returns an owned Computation to the caller on success with given root. + StatusOr BuildWithRoot(const LocalOp& root); LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); @@ -339,11 +273,11 @@ class LocalComputationBuilder { LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( - const LocalOp& operand, const LocalComputation& select, + const LocalOp& operand, const Computation& select, absl::Span window_dimensions, absl::Span window_strides, absl::Span > padding, const LocalOp& source, - const LocalOp& init_value, const LocalComputation& scatter); + const LocalOp& init_value, const Computation& scatter); LocalOp Tuple(absl::Span elements); @@ -369,7 +303,7 @@ class LocalComputationBuilder { LocalOp BitcastConvertType(const LocalOp& operand, PrimitiveType new_element_type); - LocalOp Call(const LocalComputation& local_computation, + LocalOp Call(const Computation& local_computation, absl::Span operands); LocalOp CustomCall(const string& call_target_name, @@ -384,16 +318,16 @@ class LocalComputationBuilder { LocalOp Rev(const LocalOp& operand, absl::Span dimensions); LocalOp Map(absl::Span operands, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, - const LocalComputation& local_computation, + const Computation& local_computation, absl::Span window_dimensions, absl::Span window_strides, absl::Span base_dilations, @@ -405,13 +339,13 @@ class LocalComputationBuilder { LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - LocalOp While(const LocalComputation& condition, const LocalComputation& body, + LocalOp While(const Computation& condition, const Computation& body, const LocalOp& init); LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, - const LocalComputation& true_computation, + const Computation& true_computation, const LocalOp& false_operand, - const LocalComputation& false_computation); + const Computation& false_computation); StatusOr IsConstant(const LocalOp& operand); @@ -424,19 +358,21 @@ class LocalComputationBuilder { LocalOp Cholesky(const LocalOp& a); + // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose + // enum. We use an integer here so we don't have to teach SWIG about the + // enum. LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a); + bool lower, bool unit_diagonal, int transpose_a); LocalOp Gather(const LocalOp& input, const LocalOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); LocalOp Scatter(const LocalOp& input, const LocalOp& scatter_indices, - const LocalOp& updates, - const LocalComputation& update_computation, + const LocalOp& updates, const Computation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - StatusOr BuildConstantSubGraph(const LocalOp& operand); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; @@ -478,6 +414,7 @@ class LocalComputationBuilder { _FORWARD_BINOP(Pow) _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) + _FORWARD_UNOP(Clz) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) _FORWARD_UNOP(Expm1) @@ -525,10 +462,8 @@ class LocalComputationBuilder { // Functions for freeing resources from the Python side. void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); -void DeleteXrtAllocation(XrtAllocation* allocation); -void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); -void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); -void DeleteLocalComputation(LocalComputation* computation); +void DeleteLocalExecutable(LocalExecutable* computation); +void DeleteComputation(Computation* computation); } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 7b2f69d6ecf44f492f70351b38997530567b5277..7d7a860baa03e99cc254b7596fb5f9d41acbef20 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -23,11 +23,13 @@ limitations under the License. // C++ Python // -------------------------------------+--------------------------------------- // Span <- sequence of int +// vector -> sequence of int // Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int // Span> <- sequence of int pairs @@ -97,7 +99,7 @@ limitations under the License. // wrapped in a Python class (xla_client.Shape) so as not to expose // the raw pair externally. // -// Other SWIG object wrappers (e.g. of LocalComputation) are further +// Other SWIG object wrappers (e.g. of Computation) are further // wrapped by xla_client in order to set up a custom destructor that // triggers memory deallocation on the C++ side. @@ -107,6 +109,7 @@ limitations under the License. %nothread; %include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" %{ // Must be included first @@ -124,87 +127,6 @@ limitations under the License. using namespace xla; using namespace xla::swig; -namespace xla { - -namespace swig { - -bool GetIntAttr(PyObject* o, const char* field, int64* result) { - PyObject* fo = PyObject_GetAttrString(o, field); - if (!fo) { - return false; - } - const int64 value = numpy::PyIntOrPyLongToLong(fo); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(fo); - return false; - } - Py_DECREF(fo); - *result = value; - return true; -} - -// Returns "ok"; true if there is no error, false if there was an error. -bool HandleStringAttribute(PyObject* o, - const char* attr_name, - std::function f) { - if (!PyObject_HasAttrString(o, attr_name)) { - return true; // It's ok for the object to not have the attribute. - } - PyObject* attr = PyObject_GetAttrString(o, attr_name); - if (attr == nullptr) { - return false; // An error occurred getting the attribute. - } - if (attr == Py_None) { - Py_DECREF(attr); - return true; // The attribute is None, which we consider ok. - } - if (!PyString_Check(attr)) { - string message = absl::StrFormat("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr)); - PyErr_SetString(PyExc_TypeError, message.c_str()); - Py_DECREF(attr); - return false; // Type error, not ok. - } - f(PyString_AsString(attr)); - Py_DECREF(attr); - return true; // Handled string attribute, ok! -} - -bool HandleRepeatedInt64Attribute( - PyObject* o, const char* attr_name, - tensorflow::protobuf::RepeatedField* field) { - PyObject* seq = PyObject_GetAttrString(o, attr_name); - if (!seq) { - return false; - } - - int length = PySequence_Size(seq); - if (length == -1) { - Py_DECREF(seq); - return false; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(seq, i); - if (!item) { - Py_DECREF(seq); - return false; - } - const int64 dimension = numpy::PyIntOrPyLongToLong(item); - if (dimension == -1 && PyErr_Occurred()) { - Py_DECREF(item); - Py_DECREF(seq); - return false; - } - *field->Add() = dimension; - Py_DECREF(item); - } - Py_DECREF(seq); - return true; -} - -} // namespace swig -} // namespace xla %} // Required to use PyArray_* functions. @@ -212,57 +134,6 @@ bool HandleRepeatedInt64Attribute( tensorflow::ImportNumpy(); %} -// Basic types - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = PyBool_FromLong($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) Status { - if (!$1.ok()) { - PyErr_SetString( - PyExc_RuntimeError, $1.ToString().c_str()); - SWIG_fail; - } - Py_INCREF(Py_None); - $result = Py_None; -} - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.resize(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - temps[i] = numpy::PyIntOrPyLongToLong(py_int); - if (temps[i] == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - SWIG_fail; - } - Py_DECREF(py_int); - Py_DECREF(o); - } - $1 = temps; -} - // Computation builder types %typemap(in) absl::Span( @@ -287,12 +158,12 @@ tensorflow::ImportNumpy(); // Computation and buffer/allocation types -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - auto* value = $1.ValueOrDie(); + xla::swig::LocalClient value = $1.ValueOrDie(); { - auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + auto $1 = value; + $typemap(out, xla::swig::LocalClient) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -300,12 +171,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledXrtComputation*) + $typemap(out, xla::swig::LocalExecutable*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -339,38 +210,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::XrtAllocation*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::XrtAllocationTuple*) - } - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(out) StatusOr { - if ($1.ok()) { - auto* value = $1.ValueOrDie(); - { - auto* $1 = value; - $typemap(out, xla::swig::LocalComputation*) + $typemap(out, xla::swig::Computation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -430,473 +275,6 @@ tensorflow::ImportNumpy(); $1 = temps; } -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - XrtAllocation* xrta; - if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; - } - temps.push_back(xrta); - Py_DECREF(o); - } - $1 = temps; -} - -// Literal - -%typemap(out) StatusOr { - if ($1.ok()) { - Literal value = $1.ConsumeValueOrDie(); - $result = numpy::PyObjectFromXlaLiteral(*value); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) const Literal& (StatusOr literal_status) { - literal_status = numpy::XlaLiteralFromPyObject($input); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - SWIG_fail; - } - $1 = &literal_status.ValueOrDie(); -} - -%typemap(out) Literal { - $result = numpy::PyObjectFromXlaLiteral(*$1); -} - -%typemap(out) StatusOr { - if (!$1.ok()) { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } - $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); - if (!literal_status.ok()) { - PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - Py_DECREF(o); - SWIG_fail; - } - temps.push_back(literal_status.ConsumeValueOrDie()); - Py_DECREF(o); - } - $1 = &temps; -} - -// OpMetadata - -%typemap(in) const OpMetadata& (OpMetadata temp) { - StatusOr statusor = numpy::OpMetadataFromPyObject($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -// Shape - -%typemap(out) const Shape& { - $result = numpy::PyShapeInfoFromXlaShape(*$1); -} - -%typemap(out) StatusOr { - if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - -%typemap(in) const Shape& (Shape temp) { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; -} - -%typemap(in) const absl::optional& ( - absl::optional temp) { - if ($input == Py_None) { - temp = absl::nullopt; - $1 = &temp; - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape($input); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temp = std::move(statusor).ValueOrDie(); - $1 = &temp; - } -} - -%typemap(out) std::unique_ptr { - $result = numpy::PyShapeInfoFromXlaShape(*$1); -} - -%typemap(in) const std::vector& (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - $1 = &temps; -} - -%typemap(in) const std::vector >& ( - std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (o == Py_None) { - temps.push_back(absl::nullopt); - } else { - StatusOr statusor = numpy::XlaShapeFromPyShape(o); - Py_DECREF(o); - if (!statusor.ok()) { - PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - SWIG_fail; - } - temps.push_back(statusor.ConsumeValueOrDie()); - } - } - $1 = &temps; -} - -// PrimitiveType - -%typemap(in) PrimitiveType { - PyObject* py_int = numpy::PyNumberToPyInt($input); - if (!py_int) { - PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - SWIG_fail; - } - const long value = numpy::PyIntOrPyLongToLong(py_int); - if (value == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - SWIG_fail; - } - if (!PrimitiveType_IsValid(value)) { - PyErr_SetString( - PyExc_TypeError, "Argument not valid for PrimitiveType enum"); - Py_DECREF(py_int); - SWIG_fail; - } - $1 = static_cast(value); -} - -// Span> - -%typemap(in) absl::Span > - (std::vector > temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - if (!o) { - SWIG_fail; - } - PyObject* first = PyTuple_GetItem(o, 0); - if (!first) { - Py_DECREF(o); - SWIG_fail; - } - PyObject* first_pyint = numpy::PyNumberToPyInt(first); - if (!first_pyint) { - PyErr_SetString( - PyExc_TypeError, - "First pair item cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - PyObject* second = PyTuple_GetItem(o, 1); - if (!second) { - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - PyObject* second_pyint = numpy::PyNumberToPyInt(second); - if (!second_pyint) { - PyErr_SetString( - PyExc_TypeError, - "Second pair item cannot be converted to int"); - Py_DECREF(o); - Py_DECREF(first_pyint); - SWIG_fail; - } - const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); - if (first_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); - if (second_value == -1 && PyErr_Occurred()) { - Py_DECREF(o); - Py_DECREF(first_pyint); - Py_DECREF(second_pyint); - SWIG_fail; - } - temps.push_back(std::make_pair(first_value, second_value)); - Py_DECREF(o); - } - $1 = temps; -} - -// DotDimensionNumbers - -%typemap(in) const DotDimensionNumbers& - (DotDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "lhs_contracting_dimensions", - dimension_numbers.mutable_lhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_contracting_dimensions", - dimension_numbers.mutable_rhs_contracting_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "lhs_batch_dimensions", - dimension_numbers.mutable_lhs_batch_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "rhs_batch_dimensions", - dimension_numbers.mutable_rhs_batch_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// PaddingConfig - -%typemap(in) const PaddingConfig& - (PaddingConfig padding_config) { - PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); - if (!dimensions) { - SWIG_fail; - } - - int length = PySequence_Size(dimensions); - if (length == -1) { - Py_DECREF(dimensions); - SWIG_fail; - } - - for (int i = 0; i < length; ++i) { - PyObject* item = PySequence_GetItem(dimensions, i); - if (!item) { - Py_DECREF(dimensions); - SWIG_fail; - } - int64 edge_padding_low, edge_padding_high, interior_padding; - if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) - || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) - || !GetIntAttr(item, "interior_padding", &interior_padding)) { - Py_DECREF(item); - Py_DECREF(dimensions); - SWIG_fail; - } - Py_DECREF(item); - - PaddingConfig::PaddingConfigDimension* dimension = - padding_config.add_dimensions(); - dimension->set_edge_padding_low(edge_padding_low); - dimension->set_edge_padding_high(edge_padding_high); - dimension->set_interior_padding(interior_padding); - } - Py_DECREF(dimensions); - - $1 = &padding_config; -} - -// ConvolutionDimensionNumbers - -%typemap(in) const ConvolutionDimensionNumbers& - (ConvolutionDimensionNumbers dimension_numbers) { - int64 value; - - if (!GetIntAttr($input, "input_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_batch_dimension(value); - - if (!GetIntAttr($input, "input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_input_feature_dimension(value); - - if (!GetIntAttr($input, "output_batch_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_batch_dimension(value); - - if (!GetIntAttr($input, "output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_output_feature_dimension(value); - - if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - SWIG_fail; - } - dimension_numbers.set_kernel_input_feature_dimension(value); - - if (!HandleRepeatedInt64Attribute( - $input, "input_spatial_dimensions", - dimension_numbers.mutable_input_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "kernel_spatial_dimensions", - dimension_numbers.mutable_kernel_spatial_dimensions())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "output_spatial_dimensions", - dimension_numbers.mutable_output_spatial_dimensions())) { - SWIG_fail; - } - - $1 = &dimension_numbers; -} - -// GatherDimensionNumbers - -%typemap(in) const GatherDimensionNumbers& - (GatherDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "offset_dims", - dimension_numbers.mutable_offset_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "collapsed_slice_dims", - dimension_numbers.mutable_collapsed_slice_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "start_index_map", - dimension_numbers.mutable_start_index_map())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// ScatterDimensionNumbers - -%typemap(in) const ScatterDimensionNumbers& - (ScatterDimensionNumbers dimension_numbers) { - if (!HandleRepeatedInt64Attribute( - $input, "update_window_dims", - dimension_numbers.mutable_update_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "inserted_window_dims", - dimension_numbers.mutable_inserted_window_dims())) { - SWIG_fail; - } - if (!HandleRepeatedInt64Attribute( - $input, "scatter_dims_to_operand_dims", - dimension_numbers.mutable_scatter_dims_to_operand_dims())) { - SWIG_fail; - } - - int64 value; - if (!GetIntAttr($input, "index_vector_dim", &value)) { - SWIG_fail; - } - dimension_numbers.set_index_vector_dim(value); - - $1 = &dimension_numbers; -} - -// Span - -%typemap(in) absl::Span - (std::vector temps) { - if (!PySequence_Check($input)) { - PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - SWIG_fail; - } - const int size = PySequence_Size($input); - temps.reserve(size); - for (int i = 0; i < size; ++i) { - PyObject* o = PySequence_GetItem($input, i); - ReplicaGroup rgrp; - if (!HandleRepeatedInt64Attribute( - o, "replica_ids", - rgrp.mutable_replica_ids())) { - SWIG_fail; - } - temps.push_back(rgrp); - Py_DECREF(o); - } - $1 = temps; -} - - // ExecutableBuildOptions %typemap(in) const ExecutableBuildOptions* @@ -966,160 +344,151 @@ tensorflow::ImportNumpy(); %ignoreall %unignore xla; %unignore xla::swig; -%unignore xla::swig::InitializeReplicaCount; -%unignore xla::swig::InitializePlatformName; -%unignore xla::swig::GetReplicaCount; %unignore xla::swig::RegisterCpuCustomCallTarget; -%unignore xla::swig::TransferToInfeedLocal; -%unignore xla::swig::TransferToInfeedLocalReplica; -%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::LocalClient; +%unignore xla::swig::LocalClient::Get; +%unignore xla::swig::LocalClient::DeviceCount; +%unignore xla::swig::LocalClient::TransferToInfeed; +%unignore xla::swig::LocalClient::TransferFromOutfeed; %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; %unignore xla::swig::LocalShapedBuffer::shape; +%unignore xla::swig::LocalShapedBuffer::DestructureTuple; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; -%unignore xla::swig::XrtAllocation; -%unignore xla::swig::XrtAllocation::FromLiteral; -%unignore xla::swig::XrtAllocation::ToLiteral; -%unignore xla::swig::XrtAllocation::shape; -%unignore xla::swig::XrtAllocationTuple; -%unignore xla::swig::XrtAllocationTuple::Release; -%unignore xla::swig::XrtAllocationTuple::size; -%unignore xla::swig::CompiledLocalComputation; -%unignore xla::swig::CompiledLocalComputation::Execute; -%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; -%unignore xla::swig::CompiledXrtComputation; -%unignore xla::swig::CompiledXrtComputation::Execute; -%unignore xla::swig::LocalComputation; -%unignore xla::swig::LocalComputation::Compile; -%unignore xla::swig::LocalComputation::CompileForXrt; -%unignore xla::swig::LocalComputation::GetReturnValueShape; -%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalExecutable; +%unignore xla::swig::LocalExecutable::DeviceOrdinals; +%unignore xla::swig::LocalExecutable::Execute; +%unignore xla::swig::LocalExecutable::ExecutePerReplica; +%unignore xla::swig::Computation; +%unignore xla::swig::Computation::Compile; +%unignore xla::swig::Computation::GetProgramShape; +%unignore xla::swig::Computation::GetReturnValueShape; +%unignore xla::swig::Computation::GetSerializedProto; +%unignore xla::swig::Computation::GetHloText; +%unignore xla::swig::Computation::GetHloDotGraph; %unignore xla::swig::LocalOp; -%unignore xla::swig::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; -%unignore xla::swig::LocalComputationBuilder::Build; -%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; -%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; -%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; -%unignore xla::swig::LocalComputationBuilder::Parameter; -%unignore xla::swig::LocalComputationBuilder::GetShape; -%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; -%unignore xla::swig::LocalComputationBuilder::Infeed; -%unignore xla::swig::LocalComputationBuilder::Outfeed; -%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; -%unignore xla::swig::LocalComputationBuilder::ConstantR0; -%unignore xla::swig::LocalComputationBuilder::Iota; -%unignore xla::swig::LocalComputationBuilder::BroadcastedIota; -%unignore xla::swig::LocalComputationBuilder::Broadcast; -%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; -%unignore xla::swig::LocalComputationBuilder::Pad; -%unignore xla::swig::LocalComputationBuilder::Reshape; -%unignore xla::swig::LocalComputationBuilder::Collapse; -%unignore xla::swig::LocalComputationBuilder::AllToAll; -%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; -%unignore xla::swig::LocalComputationBuilder::Slice; -%unignore xla::swig::LocalComputationBuilder::SliceInDim; -%unignore xla::swig::LocalComputationBuilder::DynamicSlice; -%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; -%unignore xla::swig::LocalComputationBuilder::ConcatInDim; -%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::Select; -%unignore xla::swig::LocalComputationBuilder::Tuple; -%unignore xla::swig::LocalComputationBuilder::GetTupleElement; -%unignore xla::swig::LocalComputationBuilder::ConvertElementType; -%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; -%unignore xla::swig::LocalComputationBuilder::Call; -%unignore xla::swig::LocalComputationBuilder::Transpose; -%unignore xla::swig::LocalComputationBuilder::Rev; -%unignore xla::swig::LocalComputationBuilder::Clamp; -%unignore xla::swig::LocalComputationBuilder::Map; -%unignore xla::swig::LocalComputationBuilder::Reduce; -%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; -%unignore xla::swig::LocalComputationBuilder::RngNormal; -%unignore xla::swig::LocalComputationBuilder::RngUniform; -%unignore xla::swig::LocalComputationBuilder::RngBernoulli; -%unignore xla::swig::LocalComputationBuilder::While; -%unignore xla::swig::LocalComputationBuilder::Conditional; -%unignore xla::swig::LocalComputationBuilder::IsConstant; -%unignore xla::swig::LocalComputationBuilder::Eq; -%unignore xla::swig::LocalComputationBuilder::Ne; -%unignore xla::swig::LocalComputationBuilder::Ge; -%unignore xla::swig::LocalComputationBuilder::Gt; -%unignore xla::swig::LocalComputationBuilder::Lt; -%unignore xla::swig::LocalComputationBuilder::Le; -%unignore xla::swig::LocalComputationBuilder::Dot; -%unignore xla::swig::LocalComputationBuilder::DotGeneral; -%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; -%unignore xla::swig::LocalComputationBuilder::Add; -%unignore xla::swig::LocalComputationBuilder::Sub; -%unignore xla::swig::LocalComputationBuilder::Mul; -%unignore xla::swig::LocalComputationBuilder::Div; -%unignore xla::swig::LocalComputationBuilder::Rem; -%unignore xla::swig::LocalComputationBuilder::Max; -%unignore xla::swig::LocalComputationBuilder::Min; -%unignore xla::swig::LocalComputationBuilder::And; -%unignore xla::swig::LocalComputationBuilder::Or; -%unignore xla::swig::LocalComputationBuilder::Xor; -%unignore xla::swig::LocalComputationBuilder::ShiftLeft; -%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; -%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; -%unignore xla::swig::LocalComputationBuilder::Not; -%unignore xla::swig::LocalComputationBuilder::Abs; -%unignore xla::swig::LocalComputationBuilder::Exp; -%unignore xla::swig::LocalComputationBuilder::Expm1; -%unignore xla::swig::LocalComputationBuilder::Floor; -%unignore xla::swig::LocalComputationBuilder::Ceil; -%unignore xla::swig::LocalComputationBuilder::Round; -%unignore xla::swig::LocalComputationBuilder::Log; -%unignore xla::swig::LocalComputationBuilder::Log1p; -%unignore xla::swig::LocalComputationBuilder::Sign; -%unignore xla::swig::LocalComputationBuilder::Cos; -%unignore xla::swig::LocalComputationBuilder::Sin; -%unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::Atan2; -%unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::Pow; -%unignore xla::swig::LocalComputationBuilder::Neg; -%unignore xla::swig::LocalComputationBuilder::Sort; -%unignore xla::swig::LocalComputationBuilder::SortKeyVal; -%unignore xla::swig::LocalComputationBuilder::Sqrt; -%unignore xla::swig::LocalComputationBuilder::Rsqrt; -%unignore xla::swig::LocalComputationBuilder::Square; -%unignore xla::swig::LocalComputationBuilder::Reciprocal; -%unignore xla::swig::LocalComputationBuilder::Erfc; -%unignore xla::swig::LocalComputationBuilder::Erf; -%unignore xla::swig::LocalComputationBuilder::ErfInv; -%unignore xla::swig::LocalComputationBuilder::Lgamma; -%unignore xla::swig::LocalComputationBuilder::Digamma; -%unignore xla::swig::LocalComputationBuilder::Acos; -%unignore xla::swig::LocalComputationBuilder::Asin; -%unignore xla::swig::LocalComputationBuilder::Atan; -%unignore xla::swig::LocalComputationBuilder::Tan; -%unignore xla::swig::LocalComputationBuilder::Acosh; -%unignore xla::swig::LocalComputationBuilder::Asinh; -%unignore xla::swig::LocalComputationBuilder::Atanh; -%unignore xla::swig::LocalComputationBuilder::Cosh; -%unignore xla::swig::LocalComputationBuilder::Sinh; -%unignore xla::swig::LocalComputationBuilder::Real; -%unignore xla::swig::LocalComputationBuilder::Imag; -%unignore xla::swig::LocalComputationBuilder::Conj; -%unignore xla::swig::LocalComputationBuilder::Complex; -%unignore xla::swig::LocalComputationBuilder::Cholesky; -%unignore xla::swig::LocalComputationBuilder::QR; -%unignore xla::swig::LocalComputationBuilder::TriangularSolve; -%unignore xla::swig::LocalComputationBuilder::CustomCall; -%unignore xla::swig::LocalComputationBuilder::Gather; -%unignore xla::swig::LocalComputationBuilder::Scatter; -%unignore xla::swig::DeleteLocalComputation; -%unignore xla::swig::DestructureLocalShapedBufferTuple; -%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::ComputationBuilder; +%unignore xla::swig::ComputationBuilder::Build; +%unignore xla::swig::ComputationBuilder::BuildWithRoot; +%unignore xla::swig::ComputationBuilder::SetOpMetadata; +%unignore xla::swig::ComputationBuilder::ClearOpMetadata; +%unignore xla::swig::ComputationBuilder::Parameter; +%unignore xla::swig::ComputationBuilder::GetShape; +%unignore xla::swig::ComputationBuilder::GetReturnValueShape; +%unignore xla::swig::ComputationBuilder::Infeed; +%unignore xla::swig::ComputationBuilder::Outfeed; +%unignore xla::swig::ComputationBuilder::ConstantLiteral; +%unignore xla::swig::ComputationBuilder::ConstantR0; +%unignore xla::swig::ComputationBuilder::Iota; +%unignore xla::swig::ComputationBuilder::BroadcastedIota; +%unignore xla::swig::ComputationBuilder::Broadcast; +%unignore xla::swig::ComputationBuilder::BroadcastInDim; +%unignore xla::swig::ComputationBuilder::Pad; +%unignore xla::swig::ComputationBuilder::Reshape; +%unignore xla::swig::ComputationBuilder::Collapse; +%unignore xla::swig::ComputationBuilder::AllToAll; +%unignore xla::swig::ComputationBuilder::CrossReplicaSum; +%unignore xla::swig::ComputationBuilder::Slice; +%unignore xla::swig::ComputationBuilder::SliceInDim; +%unignore xla::swig::ComputationBuilder::DynamicSlice; +%unignore xla::swig::ComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::ComputationBuilder::ConcatInDim; +%unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::Select; +%unignore xla::swig::ComputationBuilder::Tuple; +%unignore xla::swig::ComputationBuilder::GetTupleElement; +%unignore xla::swig::ComputationBuilder::ConvertElementType; +%unignore xla::swig::ComputationBuilder::BitcastConvertType; +%unignore xla::swig::ComputationBuilder::Call; +%unignore xla::swig::ComputationBuilder::Transpose; +%unignore xla::swig::ComputationBuilder::Rev; +%unignore xla::swig::ComputationBuilder::Clamp; +%unignore xla::swig::ComputationBuilder::Map; +%unignore xla::swig::ComputationBuilder::Reduce; +%unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::ComputationBuilder::RngNormal; +%unignore xla::swig::ComputationBuilder::RngUniform; +%unignore xla::swig::ComputationBuilder::RngBernoulli; +%unignore xla::swig::ComputationBuilder::While; +%unignore xla::swig::ComputationBuilder::Conditional; +%unignore xla::swig::ComputationBuilder::IsConstant; +%unignore xla::swig::ComputationBuilder::Eq; +%unignore xla::swig::ComputationBuilder::Ne; +%unignore xla::swig::ComputationBuilder::Ge; +%unignore xla::swig::ComputationBuilder::Gt; +%unignore xla::swig::ComputationBuilder::Lt; +%unignore xla::swig::ComputationBuilder::Le; +%unignore xla::swig::ComputationBuilder::Dot; +%unignore xla::swig::ComputationBuilder::DotGeneral; +%unignore xla::swig::ComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::ComputationBuilder::Add; +%unignore xla::swig::ComputationBuilder::Sub; +%unignore xla::swig::ComputationBuilder::Mul; +%unignore xla::swig::ComputationBuilder::Div; +%unignore xla::swig::ComputationBuilder::Rem; +%unignore xla::swig::ComputationBuilder::Max; +%unignore xla::swig::ComputationBuilder::Min; +%unignore xla::swig::ComputationBuilder::And; +%unignore xla::swig::ComputationBuilder::Or; +%unignore xla::swig::ComputationBuilder::Xor; +%unignore xla::swig::ComputationBuilder::ShiftLeft; +%unignore xla::swig::ComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::ComputationBuilder::ShiftRightLogical; +%unignore xla::swig::ComputationBuilder::Not; +%unignore xla::swig::ComputationBuilder::Clz; +%unignore xla::swig::ComputationBuilder::Abs; +%unignore xla::swig::ComputationBuilder::Exp; +%unignore xla::swig::ComputationBuilder::Expm1; +%unignore xla::swig::ComputationBuilder::Floor; +%unignore xla::swig::ComputationBuilder::Ceil; +%unignore xla::swig::ComputationBuilder::Round; +%unignore xla::swig::ComputationBuilder::Log; +%unignore xla::swig::ComputationBuilder::Log1p; +%unignore xla::swig::ComputationBuilder::Sign; +%unignore xla::swig::ComputationBuilder::Cos; +%unignore xla::swig::ComputationBuilder::Sin; +%unignore xla::swig::ComputationBuilder::Tanh; +%unignore xla::swig::ComputationBuilder::Atan2; +%unignore xla::swig::ComputationBuilder::IsFinite; +%unignore xla::swig::ComputationBuilder::Pow; +%unignore xla::swig::ComputationBuilder::Neg; +%unignore xla::swig::ComputationBuilder::Sort; +%unignore xla::swig::ComputationBuilder::SortKeyVal; +%unignore xla::swig::ComputationBuilder::Sqrt; +%unignore xla::swig::ComputationBuilder::Rsqrt; +%unignore xla::swig::ComputationBuilder::Square; +%unignore xla::swig::ComputationBuilder::Reciprocal; +%unignore xla::swig::ComputationBuilder::Erfc; +%unignore xla::swig::ComputationBuilder::Erf; +%unignore xla::swig::ComputationBuilder::ErfInv; +%unignore xla::swig::ComputationBuilder::Lgamma; +%unignore xla::swig::ComputationBuilder::Digamma; +%unignore xla::swig::ComputationBuilder::Acos; +%unignore xla::swig::ComputationBuilder::Asin; +%unignore xla::swig::ComputationBuilder::Atan; +%unignore xla::swig::ComputationBuilder::Tan; +%unignore xla::swig::ComputationBuilder::Acosh; +%unignore xla::swig::ComputationBuilder::Asinh; +%unignore xla::swig::ComputationBuilder::Atanh; +%unignore xla::swig::ComputationBuilder::Cosh; +%unignore xla::swig::ComputationBuilder::Sinh; +%unignore xla::swig::ComputationBuilder::Real; +%unignore xla::swig::ComputationBuilder::Imag; +%unignore xla::swig::ComputationBuilder::Conj; +%unignore xla::swig::ComputationBuilder::Complex; +%unignore xla::swig::ComputationBuilder::Cholesky; +%unignore xla::swig::ComputationBuilder::QR; +%unignore xla::swig::ComputationBuilder::TriangularSolve; +%unignore xla::swig::ComputationBuilder::CustomCall; +%unignore xla::swig::ComputationBuilder::Gather; +%unignore xla::swig::ComputationBuilder::Scatter; +%unignore xla::swig::DeleteComputation; %unignore xla::swig::DeleteLocalShapedBuffer; -%unignore xla::swig::DeleteXrtAllocation; -%unignore xla::swig::DeleteCompiledLocalComputation; -%unignore xla::swig::DeleteCompiledXrtComputation; +%unignore xla::swig::DeleteLocalExecutable; %thread; %include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 52c5c621f7294c5da341879d15b77559fe870551..74f45b7cdcfd7d7b10a5832be37ac1fb34057743 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -26,6 +26,10 @@ namespace swig { namespace numpy { +Safe_PyObjectPtr make_safe(PyObject* object) { + return Safe_PyObjectPtr(object); +} + int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { switch (primitive_type) { case PRED: @@ -123,28 +127,42 @@ bool NumpyTypeIsValid(int np_type) { } } -PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) { int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); - PyObject* dimensions; + Safe_PyObjectPtr dimensions; if (shape.IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(shape); - dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + dimensions = make_safe(PyTuple_New(ShapeUtil::TupleElementCount(shape))); for (int i = 0; i < num_elements; ++i) { PyTuple_SET_ITEM( - dimensions, i, - PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + dimensions.get(), i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)) + .release()); } } else { int rank = shape.rank(); - dimensions = PyTuple_New(rank); + dimensions = make_safe(PyTuple_New(rank)); for (int i = 0; i < rank; ++i) { - PyTuple_SET_ITEM(dimensions, i, + PyTuple_SET_ITEM(dimensions.get(), i, LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); } } - return PyTuple_Pack(2, np_dtype, dimensions); + return make_safe(PyTuple_Pack(2, np_dtype, dimensions.release())); +} + +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape) { + Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size())); + for (int i = 0; i < shape.parameters_size(); ++i) { + PyTuple_SET_ITEM(arg_shapes.get(), i, + PyShapeInfoFromXlaShape(shape.parameters(i)).release()); + } + + Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result()); + return make_safe( + PyTuple_Pack(2, arg_shapes.release(), result_shape.release())); } // Precondition: o->ob_type == &PyArrayDescr_Type @@ -349,13 +367,17 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); - PyObject* tuple = PyTuple_New(num_elements); + std::vector elems(num_elements); + for (int i = 0; i < num_elements; i++) { + TF_ASSIGN_OR_RETURN(elems[i], + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + } + Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements)); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM(tuple, i, - PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + PyTuple_SET_ITEM(tuple.get(), i, elems[i].release()); } return tuple; } else { @@ -365,10 +387,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); } int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); - PyObject* array = - PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); - CopyLiteralToNumpyArray(np_type, literal, - reinterpret_cast(array)); + Safe_PyObjectPtr array = make_safe( + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0)); + TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray( + np_type, literal, reinterpret_cast(array.get()))); return array; } } @@ -408,6 +430,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_BOOL: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_INT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_INT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -417,6 +445,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_UINT8: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_UINT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_UINT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -445,12 +476,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array) { +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_INT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_INT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -460,6 +497,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_UINT8: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_UINT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_UINT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -482,8 +522,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, CopyLiteralToNumpyArray(literal, py_array); break; default: - LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + return InvalidArgument( + "No XLA literal container for Numpy type number: %d", np_type); } + return Status::OK(); } PyObject* LongToPyIntOrPyLong(long x) { // NOLINT @@ -525,6 +567,92 @@ PyObject* PyNumberToPyInt(PyObject* o) { } // namespace numpy +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } +#if PY_MAJOR_VERSION < 3 + if (!PyString_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); +#else + if (!PyBytes_Check(attr)) { + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyBytes_AsString(attr)); +#endif + + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field) { + PyObject* seq = PyObject_GetAttrString(o, attr_name); + if (!seq) { + return false; + } + + int length = PySequence_Size(seq); + if (length == -1) { + Py_DECREF(seq); + return false; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(seq, i); + if (!item) { + Py_DECREF(seq); + return false; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(seq); + return false; + } + *field->Add() = dimension; + Py_DECREF(item); + } + Py_DECREF(seq); + return true; +} + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 40ff2d9ad214cc4dcad42234fa296834cbc92882..eff8cda334f00050605febad66a61aa1c518c500 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -36,6 +36,16 @@ namespace swig { namespace numpy { +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +// Safe container for an owned PyObject. On destruction, the reference count of +// the contained object will be decremented. +using Safe_PyObjectPtr = std::unique_ptr; + +Safe_PyObjectPtr make_safe(PyObject* object); + // Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy // dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and // vice versa. @@ -54,7 +64,13 @@ bool NumpyTypeIsValid(int np_type); // providing the array dimensions. // // The return value is a new reference. -PyObject* PyShapeInfoFromXlaShape(const Shape& shape); +Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple +// of argument shapes and result_shape is the result shape. Each shape is as +// described in in PyShapeInfoFromXlaShape's comment. +Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape( + const ProgramShape& shape); // Converts a Python object with a method interface mathing that of // xla_client.Shape into an XLA Shape object. @@ -74,7 +90,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,8 +106,8 @@ StatusOr XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array); +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array); template void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { @@ -120,6 +136,18 @@ PyObject* PyNumberToPyInt(PyObject* o); } // namespace numpy +// Miscellaneous swig helpers that don't have a better home. + +bool GetIntAttr(PyObject* o, const char* field, int64* result); + +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, const char* attr_name, + std::function f); + +bool HandleRepeatedInt64Attribute( + PyObject* o, const char* attr_name, + tensorflow::protobuf::RepeatedField* field); + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds index bce6c1acf8a1cc0005ca93e0466c5a0e29d880de..ef77ed3d95850fdfc7145e6fe1df4833d20bb7df 100644 --- a/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds +++ b/tensorflow/compiler/xla/python/pywrap_xla_exported_symbols.lds @@ -1 +1,2 @@ _PyInit__pywrap_xla +_init_pywrap_xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 37cae0e3b6b8635ca53e282994f0d078974df5a9..9019a979a61c6ebb62adaa5503560c604e2b30f8 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An in-process, local XLA client in Python, supporting AOT compilation.""" +"""An XLA client in Python, supporting AOT compilation.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import enum # pylint: disable=g-bad-import-order import inspect @@ -33,13 +34,32 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api from tensorflow.compiler.xla.service import hlo_pb2 +# Import the XRT backend, if available. +try: + # pylint: disable=g-import-not-at-top + from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api +except ImportError: + xrt_api = None + # Most functions are snake_case for consistency with other modules, whereas -# method names of ComputationBuilder and LocalComputation are CamelCase for +# method names of ComputationBuilder and Computation are CamelCase for # consistency with XLA. # pylint: disable=invalid-name +# Version of the XLA Python client. +# +# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is +# packaged separately from the Python code that consumes it (jax). +# +# We occasionally need to make backwards-incompatible changes to jaxlib, in +# which case we need to be able to detect when incompatible versions are +# installed. +def version(): + return (0, 1, 8) + + _OP_METADATA_FIELDS = [ 'op_type', 'op_name', @@ -49,13 +69,163 @@ _OP_METADATA_FIELDS = [ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) +@six.add_metaclass(abc.ABCMeta) +class Backend(object): + """Abstract base class for XLA backends.""" + + @abc.abstractmethod + def device_count(self): + """Returns the number of devices known to the backend.""" + + @abc.abstractmethod + def buffer_from_pyval(self, pyval, device=0): + """Allocates a fresh buffer and populates it with `pyval`.""" + + @abc.abstractmethod + def delete_buffer(self, c_buffer): + """Deletes buffer `c_buffer`.""" + + @abc.abstractmethod + def destructure_tuple(self, c_buffer): + """Destructures a tuple buffer into a sequence of buffers.""" + + @abc.abstractmethod + def compile(self, computation, argument_shapes, result_shape, + compile_options): + """Compiles a computation. Returns an executable.""" + + @abc.abstractmethod + def delete_executable(self, executable): + """Deletes an executable.""" + + @abc.abstractmethod + def execute(self, executable, args): + """Runs an executable without replication.""" + + @abc.abstractmethod + def execute_replicated(self, executable, per_replica_args): + """Runs an executable in a replicated manner.""" + + +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + +class XlaLocalBackend(Backend): + """XLA backend implemented using the in-process xla::LocalClient API.""" + + def __init__(self, platform=None): + platform = platform or _get_default_platform_name() + self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) + self._delete_buffer = c_api.DeleteLocalShapedBuffer + self._delete_executable = c_api.DeleteLocalExecutable + + def device_count(self): + return self.client.DeviceCount() + + def buffer_from_pyval(self, pyval, device=0): + return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) + + def delete_buffer(self, c_buffer): + self._delete_buffer(c_buffer) + + def destructure_tuple(self, c_buffer): + result = c_buffer.DestructureTuple() + return [result.Release(i) for i in xrange(result.size())] + + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return c_computation.Compile(argument_shapes, compile_options, self.client) + + def delete_executable(self, executable): + self._delete_executable(executable) + + def execute(self, executable, args): + return executable.Execute(args) + + def execute_replicated(self, executable, per_replica_args): + output_buffer_tup = executable.ExecutePerReplica(per_replica_args) + size = output_buffer_tup.size() + return [output_buffer_tup.Release(i) for i in xrange(size)] + + +class XrtBackend(Backend): + """XLA backend implemented using XRT.""" + + def __init__(self, target): + self.target = target + self._delete_buffer = xrt_api.DeleteXrtAllocation + self._delete_executable = xrt_api.DeleteXrtExecutable + + def device_count(self): + return 1 # Multidevice execution not implemented. + + def buffer_from_pyval(self, pyval, device=0): + if device != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + return xrt_api.XrtAllocation.FromLiteral(pyval, + _maybe_encode_string(self.target)) + + def delete_buffer(self, c_buffer): + self._delete_buffer(c_buffer) + + def destructure_tuple(self, c_buffer): + result = xrt_api.DestructureXrtAllocationTuple( + c_buffer, _maybe_encode_string(self.target)) + return [result.Release(i) for i in xrange(result.size())] + + def compile(self, c_computation, argument_shapes, result_shape, + compile_options): + return xrt_api.XrtExecutable.CompileForXrt( + c_computation.GetSerializedProto(), argument_shapes, result_shape, + _maybe_encode_string(self.target)) + + def delete_executable(self, executable): + self._delete_executable(executable) + + def execute(self, executable, args): + return executable.Execute(args) + + def execute_replicated(self, executable, per_replica_args): + if len(per_replica_args) != 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + return [executable.Execute(per_replica_args[0])] + + +_default_platform_name = 'Host' +_default_backend = None + + +def _get_default_platform_name(): + return _default_platform_name + + +def _get_default_local_backend(): + global _default_backend + global _default_platform_name + if _default_backend is None: + _default_backend = XlaLocalBackend(_default_platform_name) + return _default_backend + + class BackendType(enum.Enum): XLA_LOCAL = 1 XRT = 2 -BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target')) -XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local') +def BackendSpec(backend, target): + """Compatibility wrapper to support older clients. Do not use in new code.""" + if backend == BackendType.XLA_LOCAL: + return _get_default_local_backend() + elif backend == BackendType.XRT: + return XrtBackend(target) + else: + raise ValueError('Unknown backend {}'.format(backend)) def OpMetadataToProto(pyobj): @@ -78,13 +248,6 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) -def _maybe_encode_string(s): - if six.PY3: - return s.encode('utf-8') - else: - return s - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -122,6 +285,7 @@ def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, _UNARY_OPS = [ 'Not', + 'Clz', 'Abs', 'Exp', 'Expm1', @@ -223,33 +387,18 @@ class LocalBuffer(object): means the referent is in device memory. """ - def __init__(self, c_buffer, backend, replica): + def __init__(self, c_buffer, backend, device): self.c_buffer = c_buffer self._backend = backend - self._replica = replica - if backend.backend_type == BackendType.XRT: - self._delete = c_api.DeleteXrtAllocation - else: - self._delete = c_api.DeleteLocalShapedBuffer + self._device = device @staticmethod - def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): + def from_pyval(pyval, device=0, backend=None): """Allocate and copy to XLA the given python value.""" + backend = backend or _get_default_local_backend() pyval = require_numpy_array_layout(pyval) - num_replicas = get_replica_count() - if not 0 <= replica < num_replicas: - raise ValueError( - 'Attempt to place buffer on replica {} when the replica count is {}' - .format(replica, num_replicas)) - if backend.backend_type == BackendType.XRT: - if replica != 0: - raise NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - cbuf = c_api.XrtAllocation.FromLiteral( - pyval, _maybe_encode_string(backend.target)) - else: - cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) - return LocalBuffer(cbuf, backend, replica) + cbuf = backend.buffer_from_pyval(pyval, device) + return LocalBuffer(cbuf, backend, device) def to_py(self): return self.c_buffer.ToLiteral() @@ -257,29 +406,22 @@ class LocalBuffer(object): def shape(self): return _wrap_shape(self.c_buffer.shape()) - def replica(self): - return self._replica + def device(self): + return self._device def delete(self): if self.c_buffer is not None: - self._delete(self.c_buffer) + self._backend.delete_buffer(self.c_buffer) self.c_buffer = None def destructure(self): """Assuming a tuple buffer, unpack it into constituent tuple elements.""" assert self.c_buffer is not None - if self._backend.backend_type == BackendType.XRT: - result = c_api.DestructureXrtAllocationTuple( - self.c_buffer, _maybe_encode_string(self._backend.target)) - else: - result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer) + result = self._backend.destructure_tuple(self.c_buffer) self.delete() - size = result.size() - destructured = tuple( - LocalBuffer( - result.Release(i), replica=self._replica, backend=self._backend) - for i in xrange(size)) - return destructured + return tuple( + LocalBuffer(sub_buffer, device=self._device, backend=self._backend) + for sub_buffer in result) def is_deleted(self): return self.c_buffer is None @@ -428,6 +570,34 @@ class Shape(object): updated._check_minor_to_major() # pylint: disable=protected-access return updated + def with_major_to_minor_layout_if_absent(self): + """Returns a copy of a shape with missing layouts set to major-to-minor.""" + + def f(a): + if a.minor_to_major(): + return None + return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1))) + + return self.map_leaves(f) + + def serialize(self, proto): + """Serializes 'shape' into proto.""" + if self.is_tuple(): + proto.element_type = xla_data_pb2.TUPLE + for shape in self.tuple_shapes(): + shape.serialize(proto.tuple_shapes.add()) + else: + proto.element_type = dtype_to_etype(self.element_type()) + proto.dimensions.extend(self.dimensions()) + proto.is_dynamic_dimension.extend([False for _ in self.dimensions()]) + if self.minor_to_major(): + proto.layout.format = xla_data_pb2.DENSE + proto.layout.minor_to_major.extend(self.minor_to_major()) + + +ProgramShape = collections.namedtuple('ProgramShape', + ('parameter_shapes', 'result_shape')) + def _wrap_shape(shape_info): dtype, dims = shape_info @@ -439,6 +609,12 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) +def _wrap_program_shape(shape_info): + arg_shapes, result_shape = shape_info + return ProgramShape([_wrap_shape(arg) for arg in arg_shapes], + _wrap_shape(result_shape)) + + def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -462,7 +638,7 @@ class CompileOptions(object): self.num_replicas = get_replica_count() -def transfer_to_infeed(value, replica_number=None): +def transfer_to_infeed(value, device_ordinal=0): """Transfers the given value into the XLA infeed queue. XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with @@ -472,64 +648,50 @@ def transfer_to_infeed(value, replica_number=None): Args: value: the value that the caller would like to enqueue into the XLA infeed queue - replica_number: the replica number to infeed the value to -- if not - provided, then the default replica (trivially replica 0) is used. + device_ordinal: the device to infeed the value to. Each device has a + distinct infeed queue. """ - if replica_number is None: - c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) - else: - c_api.TransferToInfeedLocalReplica( - require_numpy_array_layout(value), replica_number) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + backend.client.TransferToInfeed( + require_numpy_array_layout(value), device_ordinal) -def transfer_from_outfeed(shape, replica_number=None): - """Transfers a literal of the given shape from replica_number's outfeed. +def transfer_from_outfeed(shape, device_ordinal=0): + """Transfers a literal of the given shape from `device_ordinal`'s outfeed. Args: shape: The shape of the value to transfer from outfeed. - replica_number: The replica number ordinal to transfer the outfeed value - from. (Each replica has a distinct outfeed queue.) + device_ordinal: The device ordinal to transfer the outfeed value from. Each + device has a distinct outfeed queue.. Returns: The literal value that is produced from the outfeed queue. """ - return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) + # TODO(phawkins): support non-default backends. + backend = _get_default_local_backend() + return backend.client.TransferFromOutfeed(shape, device_ordinal) -class LocalComputation(object): - """Python wrapper for a local XLA Computation. +class Computation(object): + """Python wrapper for an XLA Computation. - A LocalComputation can be executed if it is compiled. Otherwise, it - can still be used as a Computation where required by the - ComputationBuilder methods. + A Computation can be compiled to form an Executable, or used as a + subcomputation in ComputationBuilder methods. """ - def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): + def __init__(self, c_computation, backend=None): self._c_computation = c_computation + # The backend argument is deprecated. Pass a backend to Compile() instead. self._backend = backend - self._is_compiled = is_compiled - - # Ensure a reference to C-based destructor for use in __del__. - if is_compiled: - if backend.backend_type == BackendType.XRT: - assert isinstance(c_computation, c_api.CompiledXrtComputation) - self._delete = c_api.DeleteCompiledXrtComputation - else: - assert isinstance(c_computation, c_api.CompiledLocalComputation) - self._delete = c_api.DeleteCompiledLocalComputation - else: - assert isinstance(c_computation, c_api.LocalComputation) - self._delete = c_api.DeleteLocalComputation + self._delete_computation = c_api.DeleteComputation @property def computation(self): - if self._is_compiled: - raise ValueError( - 'Attempt to read the XLA computation of a compiled LocalComputation.') return self._c_computation def GetProto(self): - """Get the HloModuleProto proto object in this local computation. + """Get the HloModuleProto proto object in this computation. Returns: An HloModuleProto proto object that has the whole-graph information. @@ -538,30 +700,41 @@ class LocalComputation(object): proto = hlo_pb2.HloModuleProto.FromString(serialized) return proto - def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): - """Compiles an un-compiled local computation. + def GetHloText(self): + """Get the textual HLO representation of this computation. + + Returns: + A string containing the textual HLO. + """ + return self.computation.GetHloText() + + def GetHloDotGraph(self): + """Get a Graphviz Dot representation of this computation. + + Returns: + A string containing the graphviz dot graph. + """ + return self.computation.GetHloDotGraph() - Local computations are the result of a "LocalComputationBuild'ing" process - -- they start in uncompiled form, and via a call to Compile() turn into a - compiled local computation. + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, + backend=None): + """Compiles a computation. - Raises: - ValueError: if this is already a compiled local computation. + Computations are the result of a "ComputationBuild'ing" process. Arguments: argument_shapes: parameter shapes -- they are first laid out by layout_fn if layout_fn is provided. Otherwise, the default layout for those shapes will be used. - compile_options: options to use for compilation, includes an optional - laid out result shape for the computation. + compile_options: options to use for compilation, includes an optional laid + out result shape for the computation. layout_fn: lambda that is used to lay out the argument/result shapes. + backend: a `Backend` for which an executable should be generated. Returns: - A newly *compiled* local computation instance. + A Executable instance. """ - if self._is_compiled: - raise ValueError('Attempt to compile a compiled local XLA computation.') - + backend = backend or self._backend or _get_default_local_backend() result_shape = _wrap_shape(self.computation.GetReturnValueShape()) if layout_fn: @@ -574,32 +747,52 @@ class LocalComputation(object): compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - if self._backend.backend_type == BackendType.XRT: - c = self.computation.CompileForXrt( - argument_shapes, _maybe_encode_string(self._backend.target)) - else: - c = self.computation.Compile(argument_shapes, compile_options) - return LocalComputation(c, is_compiled=True, backend=self._backend) + c = backend.compile(self.computation, argument_shapes, result_shape, + compile_options) + return Executable(c, backend=backend) def CompileWithExampleArguments(self, arguments=(), compile_options=None, - layout_fn=None): + layout_fn=None, + backend=None): return self.Compile( argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, - layout_fn=layout_fn) + layout_fn=layout_fn, + backend=backend) + + def GetProgramShape(self): + return _wrap_program_shape(self._c_computation.GetProgramShape()) def GetReturnValueShape(self): return _wrap_shape(self._c_computation.GetReturnValueShape()) + def __del__(self): + if self._c_computation: + self._delete_computation(self._c_computation) + + +class Executable(object): + """Python wrapper for an XLA Executable.""" + + def __init__(self, c_executable, backend=None): + self._c_executable = c_executable + self._device_ordinals = c_executable.DeviceOrdinals() + self._backend = backend + + def DeviceOrdinals(self): + """Returns a list containing the device ordinals for each replica.""" + return self._device_ordinals + def Execute(self, arguments=(), check_for_deleted_args=True): """Execute on one replica with LocalBuffer arguments and return value.""" if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): raise ValueError('Executing with deleted local buffer argument') raw_args = [arg.c_buffer for arg in arguments] - output_buffer = self._c_computation.Execute(raw_args) - return LocalBuffer(output_buffer, backend=self._backend, replica=0) + output_buffer = self._backend.execute(self._c_executable, raw_args) + return LocalBuffer( + output_buffer, backend=self._backend, device=self._device_ordinals[0]) def ExecutePerReplica(self, arguments=None): """Execute on many replicas with LocalBuffer arguments and return value. @@ -609,14 +802,12 @@ class LocalComputation(object): sequence comprises the arguments for execution on the i'th replica. Returns: - A list of the computation's outputs on each replica, as a LocalBuffer. If + A list of the computation's outputs for each replica, as a LocalBuffer. If a shallow sequence of arguments was passed in for `arguments`, then the sole, zero'th replica's output is returned instead, as a LocalBuffer. """ - if not self._is_compiled: - raise ValueError('Cannot execute an uncompiled local XLA computation.') if arguments is None: - arguments = ((),) * get_replica_count() + arguments = ((),) * len(self._device_ordinals) else: arguments = [list(replica_args) for replica_args in arguments] @@ -625,37 +816,35 @@ class LocalComputation(object): for arg in replica_args: if arg.is_deleted(): raise ValueError('Executing with deleted local buffer argument') - if arg.replica() != replica: + if arg.device() != self._device_ordinals[replica]: raise ValueError( - 'Executing on replica {} with argument from replica {}'.format( - replica, arg.replica())) + 'Executing on device {} with argument from device {}'.format( + self._device_ordinals[replica], arg.device())) # Pull out argument buffer handles + # pylint: disable=g-complex-comprehension stripped_args = [ [arg.c_buffer for arg in replica_args] for replica_args in arguments ] # Execute - if self._backend.backend_type == BackendType.XRT: - if len(stripped_args) > 1: - raise NotImplementedError( - 'Multi-replica execution is not yet supported via the XRT backend.') - output_buffers = [self._c_computation.Execute(stripped_args[0])] - else: - output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) - size = output_buffer_tup.size() - output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] + output_buffers = self._backend.execute_replicated(self._c_executable, + stripped_args) # Wrap output handles in LocalBuffer instances return tuple( - LocalBuffer(output_buffer, backend=self._backend, replica=replica) + LocalBuffer( + output_buffer, + backend=self._backend, + device=self._device_ordinals[replica]) for replica, output_buffer in enumerate(output_buffers)) def ExecuteWithPythonValues(self, arguments=()): """Execute on one replica with Python values as arguments and output.""" def put(arg): - return LocalBuffer.from_pyval(arg, backend=self._backend) + return LocalBuffer.from_pyval( + arg, device=self._device_ordinals[0], backend=self._backend) arguments = [put(arg) for arg in arguments] return self.Execute(arguments).to_py() @@ -663,16 +852,19 @@ class LocalComputation(object): def ExecuteWithPythonValuesPerReplica(self, arguments): """Execute on many replicas with Python values as arguments and output.""" - def put(arg, replica): - return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + def put(arg, device): + return LocalBuffer.from_pyval(arg, device, backend=self._backend) - arguments = [[put(arg, replica) - for arg in replica_args] - for replica, replica_args in enumerate(arguments)] + # pylint: disable=g-complex-comprehension + arguments = [[ + put(arg, self._device_ordinals[replica]) for arg in replica_args + ] for replica, replica_args in enumerate(arguments)] return [out.to_py() for out in self.ExecutePerReplica(arguments)] def __del__(self): - self._delete(self._c_computation) + # Python may have freed c_api first. + if c_api and self._c_executable: + self._backend.delete_executable(self._c_executable) def _make_replica_group_proto(replica_group): @@ -685,8 +877,8 @@ class ComputationBuilder(object): """XLA computation builder. Enqueues XLA ops in sequence and in order to build a - LocalComputation, which in turn can be compiled into a - CompiledLocalComputation, which in turn can be locally executed. + Computation, which in turn can be compiled into a + LocalExecutable, which in turn can be locally executed. """ # The methods of this class map 1-to-1 onto the XLA C++ @@ -697,16 +889,23 @@ class ComputationBuilder(object): # pylint: disable=g-doc-args def __init__(self, name): - self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._client = c_api.ComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self, root=None, backend=XLA_LOCAL_BACKEND): + def Build(self, root=None, backend=None): + """Builds a `Computation` from the contents of the builder. + + Args: + root: if not None, the operator containing the return value of the + computation. + backend: deprecated. Pass a `backend` to `Computation.Compile` instead. + Returns: + A `Computation`. + """ if root is not None: - return LocalComputation( - self._client.BuildWithRoot(root), is_compiled=False, backend=backend) + return Computation(self._client.BuildWithRoot(root), backend=backend) else: - return LocalComputation( - self._client.Build(), is_compiled=False, backend=backend) + return Computation(self._client.Build(), backend=backend) def SetOpMetadata(self, op_metadata): """Set metadata for operations that are about to be enqueued.""" @@ -1358,7 +1557,7 @@ class ComputationBuilder(object): Args: operand: a LocalOp to test. - Returns: a LocalComputation that is rooted on the given `operand` which is a + Returns: a Computation that is rooted on the given `operand` which is a compile-time constant. """ return self._client.BuildConstantSubGraph(operand) @@ -1523,11 +1722,23 @@ class ComputationBuilder(object): """Enqueues a QR decomposition onto the computation.""" return self._client.QR(a, full_matrices) - def TriangularSolve(self, a, b, left_side=False, lower=False, - transpose_a=False, conjugate_a=False): + def TriangularSolve(self, + a, + b, + left_side=False, + lower=False, + transpose_a=False, + conjugate_a=False, + unit_diagonal=False): """Enqueues a triangular-solve operation onto the computation.""" - return self._client.TriangularSolve( - a, b, left_side, lower, transpose_a, conjugate_a) + if not transpose_a: + transpose = 1 + if conjugate_a: + a = self.Conj(a) + else: + transpose = 3 if conjugate_a else 2 + return self._client.TriangularSolve(a, b, left_side, lower, unit_diagonal, + transpose) def Gather(self, a, start_indices, dimension_numbers, slice_sizes): """Enqueues a Gather operation onto the computation.""" @@ -1547,7 +1758,7 @@ def _forward_methods_to_local_builder(): Set up methods, corresponding to unary and binary XLA operations, whose calls are forwarded in a boilerplate manner to the underlying - LocalComputationBuilder C-extension API. + ComputationBuilder C-extension API. """ def forward_to_local_builder_with_handles(target_method, is_binop=False): @@ -1567,13 +1778,13 @@ def _forward_methods_to_local_builder(): for method_name in _UNARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name)) + getattr(c_api.ComputationBuilder, method_name)) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) for method_name in _BINARY_OPS: forward = forward_to_local_builder_with_handles( - getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + getattr(c_api.ComputationBuilder, method_name), is_binop=True) forward.__name__ = method_name setattr(ComputationBuilder, method_name, forward) @@ -1581,8 +1792,14 @@ def _forward_methods_to_local_builder(): _forward_methods_to_local_builder() +_default_replica_count = 1 + + def initialize_replica_count(replica_count): - """Initializes the desired replica count to use on XLA service init. + """Initializes the default replica count to use. + + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. Args: replica_count: number of replicas that are desired for set up during XLA @@ -1591,29 +1808,30 @@ def initialize_replica_count(replica_count): Raises: A runtime exception if the XLA service has already been initialized. """ - c_api.InitializeReplicaCount(replica_count) + global _default_replica_count + _default_replica_count = replica_count -def initialize_platform_name(platform_name): - """Initializes the desired platform name to use on XLA service init. - - Args: - platform_name: string name of platform. +def get_replica_count(): + """Returns the default replica count. - Raises: - A runtime exception if the XLA service has already been initialized. + Deprecated; pass `num_replicas` as an option to `Computation.Compile()` + instead. """ - platform_name = _maybe_encode_string(platform_name) - c_api.InitializePlatformName(platform_name) + return _default_replica_count -def get_replica_count(): - """Returns the current replica count used for the XLA service. +def initialize_platform_name(platform_name): + """Initializes the default platform name to use for XLA. - Note: this will return a value whether the XLA service has been initialized - yet or not. + Args: + platform_name: string name of platform. """ - return c_api.GetReplicaCount() + global _default_platform_name + _default_platform_name = platform_name + + # Make sure the platform is valid by trying to instantiate it. + _get_default_local_backend() def register_cpu_custom_call_target(name, fn): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c80e792464560f4722b657694d8eb6f5e03956a9..51ef7d7f3a17f341e955f48615b05a886813430b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -29,7 +29,7 @@ from tensorflow.compiler.xla.python import xla_client import unittest -class LocalComputationTest(unittest.TestCase): +class ComputationTest(unittest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -85,9 +85,35 @@ def NumpyArrayBool(*args, **kwargs): return np.array(*args, dtype=np.bool, **kwargs) -class ComputationsWithConstantsTest(LocalComputationTest): +class ComputationPrinting(unittest.TestCase): + + def ExampleComputation(self): + builder = xla_client.ComputationBuilder("acomputation") + p0 = builder.ParameterFromNumpy(np.float32(0)) + p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) + builder.Mul(p0, p1) + return builder.Build() + + def testComputationToHloText(self): + computation = self.ExampleComputation() + hlo_text = computation.GetHloText() + self.assertTrue(hlo_text.startswith("HloModule acomputation")) + + def testComputationToHloGraph(self): + computation = self.ExampleComputation() + hlo_dot_graph = computation.GetHloDotGraph() + self.assertTrue(hlo_dot_graph.startswith("digraph ")) + + +class ComputationsWithConstantsTest(ComputationTest): """Tests focusing on Constant ops.""" + def testConstantScalarSumS8(self): + c = self._NewComputation() + root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + self._ExecuteAndCompareExact(c, expected=np.int8(3)) + def testConstantScalarSumF32(self): c = self._NewComputation() root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) @@ -298,7 +324,7 @@ class ComputationsWithConstantsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=0.75) -class ParametersTest(LocalComputationTest): +class ParametersTest(ComputationTest): """Tests focusing on Parameter ops and argument-passing.""" def setUp(self): @@ -378,7 +404,7 @@ class ParametersTest(LocalComputationTest): expected=[-4.3, 1.3, -6.3, 3.3]) -class LocalBufferTest(LocalComputationTest): +class LocalBufferTest(ComputationTest): """Tests focusing on execution with LocalBuffers.""" def _Execute(self, c, arguments): @@ -476,7 +502,7 @@ class LocalBufferTest(LocalComputationTest): self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) -class SingleOpTest(LocalComputationTest): +class SingleOpTest(ComputationTest): """Tests for single ops. The goal here is smoke testing - to exercise the most basic functionality of @@ -751,6 +777,12 @@ class SingleOpTest(LocalComputationTest): c.Not(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=~arr) + def testCountLeadingZeros(self): + c = self._NewComputation() + arr = NumpyArrayS32([0x7FFF, 0x12345678]) + c.Clz(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=[17, 3]) + def testExp(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -1169,7 +1201,7 @@ class SingleOpTest(LocalComputationTest): np.testing.assert_allclose(g, expected, rtol=1e-4) -class EmbeddedComputationsTest(LocalComputationTest): +class EmbeddedComputationsTest(ComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" def _CreateConstantS32Computation(self): @@ -1633,7 +1665,7 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose(c, expected=expected) -class ErrorTest(LocalComputationTest): +class ErrorTest(ComputationTest): def setUp(self): self.f32_scalar_2 = NumpyArrayF32(2.0) @@ -1650,7 +1682,7 @@ class ErrorTest(LocalComputationTest): lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) -class ComputationRootTest(LocalComputationTest): +class ComputationRootTest(ComputationTest): """Tests related to setting the root of the computation.""" def testComputationRootDifferentFromLastOp(self): diff --git a/tensorflow/compiler/xla/python/xla_data.i b/tensorflow/compiler/xla/python/xla_data.i new file mode 100644 index 0000000000000000000000000000000000000000..974f314af24f61c0015a8d51c16dff1bfc84c7cc --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_data.i @@ -0,0 +1,654 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// xla_data.h. +// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// Span <- sequence of int +// vector -> sequence of int +// Span <- sequence of LocalOp +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// ProgramShape -> pair of ([arg_shapes], ret_shape) +// std::vector <- sequence of xla_client.Shape objects +// PrimitiveType <- int +// Span> <- sequence of int pairs +// PaddingConfig proto <- corresponding Python proto +// ConvolutionDimensionNumbers proto <- corresponding Python proto +// DotDimensionNumbers proto <- corresponding Python proto +// GatherDimensionNumbers proto <- corresponding Python proto +// ScatterDimensionNumbers proto <- corresponding Python proto +// Span <- sequence of ReplicaGroup Python proto +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// Shapes output by C++ become Python objects with the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. +// +// Other SWIG object wrappers (e.g. of Computation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. + +%module(threads="1") xla_data + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Basic types + + +%typemap(out) std::vector { + PyObject* out = PyList_New($1.size()); + for (int i = 0; i < $1.size(); ++i) { + PyList_SET_ITEM(out, i, PyInt_FromLong($1[i])); + } + $result = out; +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyString_FromString($1.ConsumeValueOrDie().c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + SWIG_fail; + } + Py_INCREF(Py_None); + $result = Py_None; +} + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + SWIG_fail; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (StatusOr literal_status) { + literal_status = numpy::XlaLiteralFromPyObject($input); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + SWIG_fail; + } + $1 = &literal_status.ValueOrDie(); +} + +%typemap(out) Literal (StatusOr obj_status) { + obj_status = numpy::PyObjectFromXlaLiteral(*$1); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(out) StatusOr (StatusOr obj_status) { + if (!$1.ok()) { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } + obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + Py_DECREF(o); + SWIG_fail; + } + temps.push_back(literal_status.ConsumeValueOrDie()); + Py_DECREF(o); + } + $1 = &temps; +} + +// OpMetadata + +%typemap(in) const OpMetadata& (OpMetadata temp) { + StatusOr statusor = numpy::OpMetadataFromPyObject($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +// Shape + +%typemap(out) const Shape& { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyProgramShapeInfoFromXlaProgramShape( + $1.ConsumeValueOrDie()).release(); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) const Shape& (Shape temp) { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +%typemap(in) const absl::optional& ( + absl::optional temp) { + if ($input == Py_None) { + temp = absl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1).release(); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(absl::nullopt); + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + SWIG_fail; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + SWIG_fail; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + SWIG_fail; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + SWIG_fail; + } + $1 = static_cast(value); +} + +// Span> + +%typemap(in) absl::Span > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + SWIG_fail; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + SWIG_fail; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + SWIG_fail; + } + const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + SWIG_fail; + } + temps.push_back(std::make_pair(first_value, second_value)); + Py_DECREF(o); + } + $1 = temps; +} + +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "lhs_contracting_dimensions", + dimension_numbers.mutable_lhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_contracting_dimensions", + dimension_numbers.mutable_rhs_contracting_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "lhs_batch_dimensions", + dimension_numbers.mutable_lhs_batch_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "rhs_batch_dimensions", + dimension_numbers.mutable_rhs_batch_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + SWIG_fail; + } + + int length = PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + SWIG_fail; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + SWIG_fail; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + SWIG_fail; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + SWIG_fail; + } + dimension_numbers.set_kernel_input_feature_dimension(value); + + if (!HandleRepeatedInt64Attribute( + $input, "input_spatial_dimensions", + dimension_numbers.mutable_input_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "kernel_spatial_dimensions", + dimension_numbers.mutable_kernel_spatial_dimensions())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "output_spatial_dimensions", + dimension_numbers.mutable_output_spatial_dimensions())) { + SWIG_fail; + } + + $1 = &dimension_numbers; +} + +// GatherDimensionNumbers + +%typemap(in) const GatherDimensionNumbers& + (GatherDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "offset_dims", + dimension_numbers.mutable_offset_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "collapsed_slice_dims", + dimension_numbers.mutable_collapsed_slice_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "start_index_map", + dimension_numbers.mutable_start_index_map())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// ScatterDimensionNumbers + +%typemap(in) const ScatterDimensionNumbers& + (ScatterDimensionNumbers dimension_numbers) { + if (!HandleRepeatedInt64Attribute( + $input, "update_window_dims", + dimension_numbers.mutable_update_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "inserted_window_dims", + dimension_numbers.mutable_inserted_window_dims())) { + SWIG_fail; + } + if (!HandleRepeatedInt64Attribute( + $input, "scatter_dims_to_operand_dims", + dimension_numbers.mutable_scatter_dims_to_operand_dims())) { + SWIG_fail; + } + + int64 value; + if (!GetIntAttr($input, "index_vector_dim", &value)) { + SWIG_fail; + } + dimension_numbers.set_index_vector_dim(value); + + $1 = &dimension_numbers; +} + +// Span + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + ReplicaGroup rgrp; + if (!HandleRepeatedInt64Attribute( + o, "replica_ids", + rgrp.mutable_replica_ids())) { + SWIG_fail; + } + temps.push_back(rgrp); + Py_DECREF(o); + } + $1 = temps; +} diff --git a/tensorflow/compiler/xla/python/xrt.cc b/tensorflow/compiler/xla/python/xrt.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c55abc17f87c369e3d5b2140a84014e07921a9a --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.cc @@ -0,0 +1,297 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/xrt.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace swig { + +XrtAllocation::XrtAllocation(int64 handle, Shape shape, + const string& session_target) + : handle_(handle), shape_(shape), session_target_(session_target) {} + +XrtAllocation::~XrtAllocation() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +/* static */ +StatusOr XrtAllocation::FromLiteral( + const Literal& argument, const string& session_target) { + xrt::XLAAllocation alloc; + *alloc.mutable_value() = argument.ToProto(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto literal_string = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({literal_string, alloc.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtAllocation(handle, argument.shape(), session_target); +} + +const int64 XrtAllocation::handle() const { return handle_; } + +const Shape& XrtAllocation::shape() const { return shape_; } + +StatusOr XrtAllocation::ToLiteral() const { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + + xla::LiteralProto response; + TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); + return Literal::CreateFromProto(response); +} + +XrtAllocationTuple::XrtAllocationTuple(std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + CHECK(element != nullptr); + } +} + +XrtAllocationTuple::~XrtAllocationTuple() { + for (XrtAllocation* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr XrtAllocationTuple::Release(int i) { + XrtAllocation* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int64 XrtAllocationTuple::size() const { return elements_.size(); } + +StatusOr XrtExecutable::CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto compile = tensorflow::ops::XRTCompile(root, program); + TF_RETURN_IF_ERROR(root.status()); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + ProgramShape program_shape; + for (auto& shape : argument_shapes) { + *program_shape.add_parameters() = shape; + } + *program_shape.mutable_result() = result_shape; + + LayoutUtil::SetToDefaultLayout(&program_shape); + *config->mutable_program_shape() = program_shape.ToProto(); + c.mutable_hlo_snapshot() + ->mutable_hlo() + ->mutable_hlo_module() + ->ParsePartialFromString(hlo_module_proto); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({program, c.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtExecutable(program_shape, handle, session_target); +} + +XrtExecutable::XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target) + : program_shape_(program_shape), + handle_(handle), + session_target_(session_target) {} + +XrtExecutable::~XrtExecutable() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({computation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} + +StatusOr XrtExecutable::Execute( + absl::Span argument_handles) { + const int num_expected_arguments = program_shape().parameters().size(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + std::vector arguments; + arguments.reserve(num_expected_arguments); + for (int i = 0; i < num_expected_arguments; ++i) { + arguments.push_back( + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); + } + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto execution_config = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto execute = tensorflow::ops::XRTExecute(root, computation_handle, + execution_config, arguments); + TF_RETURN_IF_ERROR(root.status()); + + TF_RET_CHECK(argument_handles.size() == arguments.size()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + for (int i = 0; i < arguments.size(); ++i) { + inputs.insert({arguments[i], argument_handles[i]->handle()}); + } + inputs.insert({computation_handle, handle()}); + inputs.insert({execution_config, e.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); + + int64 output = outputs[0].scalar()(); + return new XrtAllocation(output, program_shape().result(), session_target_); +} + +const ProgramShape& XrtExecutable::program_shape() const { + return program_shape_; +} + +int64 XrtExecutable::handle() const { return handle_; } + +void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } + +void DeleteXrtExecutable(XrtExecutable* computation) { delete computation; } + +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target) { + const Shape& tuple_shape = allocation->shape(); + + if (!tuple_shape.IsTuple()) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); + } + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); + auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + std::vector results; + for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + inputs.clear(); + inputs.insert({base_handle, allocation->handle()}); + inputs.insert({shape_index, {i}}); + std::vector outputs; + auto status = session.Run(inputs, {subtuple}, &outputs); + if (!status.ok()) { + // Clean up before returning non-ok status. + for (int j = 0; j < results.size(); ++j) { + delete results[j]; + } + return status; + } + const int64 subtuple_handle = outputs[0].scalar()(); + const Shape& subtuple_shape = + ShapeUtil::GetTupleElementShape(tuple_shape, i); + results.push_back( + new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); + } + return new XrtAllocationTuple(std::move(results)); +} + +} // namespace swig +} // namespace xla diff --git a/tensorflow/compiler/xla/python/xrt.h b/tensorflow/compiler/xla/python/xrt.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5bba6d5c9641dadc323f70745e870c14543321 --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.h @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape.h" + +namespace xla { +namespace swig { + +// Represents a reference to literals that live in a device-allocated buffer via +// XRT. Specifically, wraps an int64 handle produced by running the allocation +// graph, and an XLA shape to track the referent's shape. +class XrtAllocation { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which allocation and deallocation + // graphs are run. + static StatusOr FromLiteral(const Literal& argument, + const string& session_target); + + XrtAllocation(int64 handle, Shape shape, const string& session_target); + ~XrtAllocation(); + StatusOr ToLiteral() const; + const Shape& shape() const; + const int64 handle() const; + + private: + const int64 handle_; + const Shape shape_; + const string session_target_; +}; + +// Result of a tuple destructuring operation on an XrtAllocation. +class XrtAllocationTuple { + public: + // Note: any XrtAllocation elements that are not Release()'d will be + // deallocated in the destructor. + explicit XrtAllocationTuple(std::vector elements); + + ~XrtAllocationTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int64 size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued XrtAllocation into its constitutent elements +// in XrtAllocationTuple form. +// +// Accepts a `session_target` argument, used in constructing the +// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, +// and passed along in constructing each constituent XrtAllocation. +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target); + +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Specifically, wraps an XRT computation handle. +class XrtExecutable { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the compilation graph is run. + static StatusOr CompileForXrt( + const string& hlo_module_proto, const std::vector& argument_shapes, + const Shape& result_shape, const string& session_target); + + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the execution graph is run. + XrtExecutable(const ProgramShape& program_shape, int64 handle, + const string& session_target); + ~XrtExecutable(); + + std::vector DeviceOrdinals() const { return {0}; } + + StatusOr Execute( + absl::Span argument_handles); + + const ProgramShape& program_shape() const; + int64 handle() const; + + private: + const ProgramShape program_shape_; + const int64 handle_; + const string session_target_; +}; + +// Functions for freeing resources from the Python side. +void DeleteXrtAllocation(XrtAllocation* allocation); +void DeleteXrtExecutable(XrtExecutable* computation); + +} // namespace swig +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_XRT_H_ diff --git a/tensorflow/compiler/xla/python/xrt.i b/tensorflow/compiler/xla/python/xrt.i new file mode 100644 index 0000000000000000000000000000000000000000..456dd7be86e479b46815fc16b51a10431fe2060d --- /dev/null +++ b/tensorflow/compiler/xla/python/xrt.i @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Wrappers for XRT ops. + +%module(threads="1") xrt + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" +%include "tensorflow/compiler/xla/python/xla_data.i" + +%{ +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/python/xrt.h" + +using namespace xla; +using namespace xla::swig; + +%} + +// Computation and buffer/allocation types + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtExecutable*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocationTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + XrtAllocation* xrta; + if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), + SWIG_POINTER_EXCEPTION)) == -1) { + SWIG_fail; + } + temps.push_back(xrta); + Py_DECREF(o); + } + $1 = temps; +} + + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::XrtAllocation; +%unignore xla::swig::XrtAllocation::FromLiteral; +%unignore xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; +%unignore xla::swig::XrtAllocationTuple; +%unignore xla::swig::XrtAllocationTuple::Release; +%unignore xla::swig::XrtAllocationTuple::size; +%unignore xla::swig::XrtExecutable; +%unignore xla::swig::XrtExecutable::CompileForXrt; +%unignore xla::swig::XrtExecutable::DeviceOrdinals; +%unignore xla::swig::XrtExecutable::Execute; +%unignore xla::swig::DestructureXrtAllocationTuple; +%unignore xla::swig::DeleteXrtAllocation; +%unignore xla::swig::DeleteXrtExecutable; + +%thread; +%include "tensorflow/compiler/xla/python/xrt.h" +%nothread; + +%unignoreall diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4f6509c1cb9dddac3f90cb8bea9b8ee989e4da4b..8d8394cb43ee013b9396a54e3a4d037445fcc0e1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -114,6 +114,7 @@ tf_cc_test( ":bfloat16_normalization", ":bfloat16_support", ":hlo", + ":hlo_creation_utils", ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -679,7 +680,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -1203,7 +1203,6 @@ cc_library( ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -1461,11 +1460,15 @@ cc_library( hdrs = ["hlo_creation_utils.h"], deps = [ ":hlo", + ":hlo_module_config", ":shape_inference", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1510,6 +1513,20 @@ cc_library( ], ) +cc_library( + name = "op_expander_pass", + srcs = ["op_expander_pass.cc"], + hdrs = ["op_expander_pass.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "gather_expander", srcs = ["gather_expander.cc"], @@ -1518,6 +1535,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", + ":op_expander_pass", ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", @@ -1541,6 +1559,28 @@ cc_library( ], ) +cc_library( + name = "triangular_solve_expander", + srcs = ["triangular_solve_expander.cc"], + hdrs = ["triangular_solve_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1602,7 +1642,7 @@ tf_cc_test( ":algebraic_simplifier", ":hlo", ":hlo_casting_utils", - ":hlo_matchers", + ":hlo_creation_utils", ":hlo_parser", ":hlo_pass", ":pattern_matcher", @@ -2163,6 +2203,8 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", @@ -2306,6 +2348,7 @@ tf_cc_test( srcs = ["hlo_dataflow_analysis_test.cc"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_dataflow_analysis", ":hlo_graph_dumper", ":hlo_matchers", @@ -2476,6 +2519,7 @@ tf_cc_test( srcs = ["tuple_points_to_analysis_test.cc"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", @@ -2851,7 +2895,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -3029,8 +3072,6 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", ], ) @@ -3192,32 +3233,6 @@ tf_cc_test( ], ) -cc_library( - name = "hlo_tfgraph_builder", - srcs = ["hlo_tfgraph_builder.cc"], - hdrs = ["hlo_tfgraph_builder.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "hlo_tfgraph_builder_test", - srcs = ["hlo_tfgraph_builder_test.cc"], - deps = [ - ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "hlo_graph_dumper", srcs = [ @@ -3229,7 +3244,6 @@ cc_library( ":hlo", ":hlo_casting_utils", ":hlo_execution_profile", - ":hlo_tfgraph_builder", ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -3273,7 +3287,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -3489,6 +3502,37 @@ tf_cc_test( ], ) +cc_library( + name = "stable_sort_expander", + srcs = ["stable_sort_expander.cc"], + hdrs = ["stable_sort_expander.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "stable_sort_expander_test", + srcs = ["stable_sort_expander_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_matchers", + ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + ":stable_sort_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_util", srcs = ["tuple_util.cc"], @@ -3585,7 +3629,6 @@ cc_library( ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", ], @@ -3641,7 +3684,6 @@ cc_library( ":hlo_evaluator", ":hlo_pass", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -3804,7 +3846,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/container:flat_hash_map", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index cd06cfcdd38d56a43def8a531fb7f018b22ed888..bd17e96106abd9de0dd3bbf418439b0fb3edb746 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -280,15 +280,51 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { hlo)); } - // Helper method to perform and add reduction in a single dimension. - HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + // Converts to primitive type if the input hlo is not that type, otherwise + // returns the original hlo. + HloInstruction* AsType(HloInstruction* hlo, + const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); + } + + // Transposes a dot operand such that the batch dimensions are the msot major, + // and the contracting dimensions are most minor. + StatusOr NormalizeDotOperandToBatchMajorAndContractingMinor( + HloInstruction* dot_operand, absl::Span batch_dimensions, + absl::Span contracting_dimensions) { + std::vector transpose_dimensions(batch_dimensions.begin(), + batch_dimensions.end()); + for (int64 i = 0; i < dot_operand->shape().rank(); ++i) { + if (!(absl::c_linear_search(batch_dimensions, i) || + absl::c_linear_search(contracting_dimensions, i))) { + transpose_dimensions.push_back(i); + } + } + transpose_dimensions.insert(transpose_dimensions.end(), + contracting_dimensions.begin(), + contracting_dimensions.end()); + return MakeTransposeHlo(dot_operand, transpose_dimensions); + } + + // Helper method to perform and add reduction on a list of dimensions. + HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); - Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + Shape shape = ShapeUtil::FilterDimensions( + [&](int64 dim) { return !absl::c_linear_search(dims, dim); }, + hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( - shape, hlo, zero, {dim}, AddReduce_computation)); + shape, hlo, zero, dims, AddReduce_computation)); + } + + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + return AddReduce(hlo, std::vector{dim}); } // Convenience method for replacing an instruction with a bitcast. If operand @@ -892,7 +928,6 @@ std::unique_ptr TryDivideToShift(HloInstruction* divide, } // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - Shape* shape; HloInstruction *a, *b, *c, *d; CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A @@ -955,6 +990,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { break; } + Shape* shape; // exp(A)/exp(B) => exp(A-B) if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) .WithShape(m::Shape(&shape)))) { @@ -1005,8 +1041,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - Literal new_literal(b->shape()); - switch (b->shape().element_type()) { + Shape result_shape = b->literal().shape(); + Literal new_literal(result_shape); + switch (result_shape.element_type()) { case F16: TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); break; @@ -1089,7 +1126,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( const int64 rhs_rank = rhs->shape().rank(); const int64 lhs_rank = lhs->shape().rank(); const auto& dnums = dot->dot_dimension_numbers(); - if (dnums.rhs_contracting_dimensions_size() > 1) { + if (dnums.rhs_contracting_dimensions_size() != 1) { return false; } if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { @@ -1119,16 +1156,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( std::swap(rhs_collapsing_dim, rhs_kept_dim); } - auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) { - if (hlo->shape().element_type() == element_type) { - return hlo; - } - return computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo)); - }; - auto reshape_if_necessary = [&](HloInstruction* hlo) { - hlo = as_type(hlo, dot->shape().element_type()); + hlo = AsType(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( HloInstruction::CreateReshape(dot->shape(), hlo)); @@ -1137,7 +1166,7 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { - return AddReduce(as_type(hlo, F32), dim); + return AddReduce(AsType(hlo, F32), dim); }; auto broadcast = [&](HloInstruction* hlo, const Shape& shape, @@ -1246,8 +1275,8 @@ StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( return dims; }; - // If the contracting dimension is 1, remove the degnerate dimnesions from the - // lhs and rhs, broadcast each to the result shape and multiply. + // If the contracting dimension is 1, remove the degnerate dimnensions from + // the lhs and rhs, broadcast each to the result shape and multiply. if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 && (rhs_kept_dim == rhs_rank - 1 || (rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) { @@ -1584,7 +1613,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); - + if (options_.is_layout_sensitive()) { + return Status::OK(); + } // Replace a zero element dot with a broadcast of the constant 0. if (ShapeUtil::IsZeroElementArray(dot->shape()) || ShapeUtil::IsZeroElementArray(lhs->shape()) || @@ -1601,6 +1632,117 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot->shape().element_type() != BF16) { return Status::OK(); } + + // If there are no contracting dimensions, a dot can be rewritten as + // mul(broadcast(transpose(x)),broadcast(transpose(y))) + if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + if (dot->shape().rank() != lhs->shape().rank()) { + std::vector lhs_broadcast_dims(lhs->shape().rank()); + absl::c_iota(lhs_broadcast_dims, 0); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_lhs, lhs_broadcast_dims)); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + if (dot->shape().rank() != rhs->shape().rank()) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) { + rhs_broadcast_dims.push_back(i); + } + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_rhs, rhs_broadcast_dims)); + } + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, + new_lhs, new_rhs)); + } + + // If the lhs or rhs have only batch and contracting dimensions, a dot can be + // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) + if ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == + lhs->shape().rank()) || + (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size() == + rhs->shape().rank())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_lhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + lhs, + AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * new_rhs, + NormalizeDotOperandToBatchMajorAndContractingMinor( + rhs, + AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), + AsInt64Slice( + dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + + int64 lhs_outer_dims = + lhs->shape().rank() - + (dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + int64 rhs_outer_dims = + rhs->shape().rank() - + (dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + dot->dot_dimension_numbers().rhs_contracting_dimensions_size()); + CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0); + if (rhs_outer_dims > 0) { + std::vector lhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(lhs_broadcast_dims, 0); + lhs_broadcast_dims.resize(lhs->shape().rank()); + std::iota(lhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().lhs_batch_dimensions_size(), + lhs_broadcast_dims.end(), + dot->dot_dimension_numbers().lhs_batch_dimensions_size() + + rhs_outer_dims); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_rhs->shape(), new_lhs, lhs_broadcast_dims)); + } else if (lhs_outer_dims > 0) { + std::vector rhs_broadcast_dims( + dot->dot_dimension_numbers().rhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + rhs_broadcast_dims.resize(rhs->shape().rank()); + std::iota(rhs_broadcast_dims.begin() + + dot->dot_dimension_numbers().rhs_batch_dimensions_size(), + rhs_broadcast_dims.end(), + dot->dot_dimension_numbers().rhs_batch_dimensions_size() + + lhs_outer_dims); + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + new_lhs->shape(), new_rhs, rhs_broadcast_dims)); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); + std::vector reduce_dims( + dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); + new_dot = AsType(new_dot, F32); + const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); + absl::c_iota( + reduce_dims, + outer_dims + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + new_dot = AddReduce(new_dot, reduce_dims); + new_dot = AsType(new_dot, dot->shape().element_type()); + return ReplaceInstruction(dot, new_dot); + } + if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || dot->shape().rank() > 2) { if (options_.enable_dot_strength_reduction() && @@ -1639,7 +1781,11 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). - if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + if (dot->dot_dimension_numbers().lhs_batch_dimensions_size() == 0 && + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 1 && + dot->dot_dimension_numbers().lhs_contracting_dimensions(0) == 1 && + dot->dot_dimension_numbers().rhs_contracting_dimensions(0) == 0 && + lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { DotDimensionNumbers dot_dimension_numbers; dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); @@ -2529,11 +2675,11 @@ StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( int64 start = slice->slice_starts(i); int64 low = padding_config.dimensions(i).edge_padding_low(); int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (start < low || start >= low + data) { + return true; } } - return true; + return false; }(); if (in_padding) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index f55a1886b8f86af4893c8a7fc18ed935d223eca0..af03fcb100813e8942efcaefc296b971c01a6aaa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -46,7 +46,6 @@ namespace { using ::testing::ElementsAre; namespace m = match; -namespace op = xla::testing::opcode_matchers; class AlgebraicSimplifierTest : public HloTestBase { protected: @@ -2749,12 +2748,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewVerifiedModule(); + TF_ASSERT_OK(MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, &builder, + module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -2763,6 +2764,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); @@ -2772,10 +2774,11 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { HloInstruction::CreateParameter(1, values_shape, "values0")); auto values1 = builder.AddInstruction( HloInstruction::CreateParameter(2, values_shape, "values1")); - builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, - keys, {values0, values1})); - auto module = CreateNewVerifiedModule(); + TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape( + {keys_shape, values_shape, values_shape}), + {keys, values0, values1}, 0, /*is_stable=*/false, + &builder, module.get()) + .status()); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); @@ -3711,8 +3714,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); @@ -3957,7 +3960,7 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { param = f32[3,4] parameter(0) constant = f32[] constant(0.0) pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 - ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[4:5]} } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -3968,6 +3971,27 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalarConstant) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { const char* hlo_string = R"( HloModule module @@ -3989,6 +4013,29 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { EXPECT_THAT(root, GmockMatch(m::Parameter())); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadSomeDimsInPadding) { + const char* hlo_string = R"( + HloModule module + + ENTRY entry () -> f32[1]{0} { + constant.val = f32[] constant(4) + constant.pad = f32[] constant(-7) + reshape.1 = f32[1,1,1]{2,1,0} reshape(f32[] constant.val) + pad = f32[3,3,3]{2,1,0} pad(f32[1,1,1]{2,1,0} reshape.1, f32[] constant.pad), padding=0_2x0_2x2_0 + slice = f32[1,1,1]{2,1,0} slice(f32[3,3,3]{2,1,0} pad), slice={[0:1], [0:1], [0:1]} + ROOT reshape.2 = f32[1]{0} reshape(f32[1,1,1]{2,1,0} slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::ConstantScalar(-7.0)))); +} + TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { const char* hlo_string = R"( HloModule module @@ -4219,10 +4266,24 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { int m, k, n; PrimitiveType element_type; std::tie(m, k, n, element_type) = GetParam(); - - Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); - Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + std::vector lhs_dims = {1, 3, 5}; + std::vector rhs_dims = lhs_dims; + std::vector output_dims = lhs_dims; + if (m > 0) { + lhs_dims.push_back(m); + output_dims.push_back(m); + } + if (k > 0) { + lhs_dims.push_back(k); + rhs_dims.push_back(k); + } + if (n > 0) { + rhs_dims.push_back(n); + output_dims.push_back(n); + } + Shape dot_shape = ShapeUtil::MakeShape(element_type, output_dims); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, lhs_dims); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, rhs_dims); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction( @@ -4236,16 +4297,18 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { dot_dnums.add_rhs_batch_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); dot_dnums.add_rhs_batch_dimensions(2); - dot_dnums.add_lhs_contracting_dimensions(4); - dot_dnums.add_rhs_contracting_dimensions(3); + if (k > 0) { + dot_dnums.add_lhs_contracting_dimensions(m > 0 ? 4 : 3); + dot_dnums.add_rhs_contracting_dimensions(3); + } builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); - const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; - const bool computation_should_be_modified = dot_should_be_transformed; - EXPECT_EQ(changed, computation_should_be_modified); + const bool dot_should_be_transformed = + m == 1 || k == 1 || n == 1 || m == -1 || k == -1 || n == -1; + EXPECT_EQ(changed, dot_should_be_transformed); bool has_no_dot = true; for (const auto& hlo : computation->instructions()) { if (hlo->opcode() == HloOpcode::kDot) { @@ -4256,10 +4319,12 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { EXPECT_EQ(has_no_dot, dot_should_be_transformed); } -INSTANTIATE_TEST_SUITE_P( - BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, - ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), - ::testing::Values(1, 2), ::testing::Values(F32, BF16))); +INSTANTIATE_TEST_SUITE_P(BatchDotStrengthReductionTestInstantiation, + BatchDotStrengthReductionTest, + ::testing::Combine(::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(-1, 1, 2), + ::testing::Values(F32, BF16))); class DotStrengthReductionTest : public AlgebraicSimplifierTest, @@ -4789,7 +4854,29 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) { AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {}); + shape.clear_layout(); + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + + HloInstruction* const_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(20.0f))); + builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + param, const_value)); + + std::unique_ptr module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Multiply())); } } // namespace diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index f8dff6a700cc9d5843053e3d451a7b005539ca26..52d6982c70f7962ea9f54db0a4b1f2089a122c1c 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -29,19 +29,16 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace { - namespace m = match; // Checks if the argument instruction is an AllReduce, followed by a certain // sequence of instructions and then a CRS. It must be possible to move // the AR past each instruction in the sequence. Returns the CRS, which is the // last instruction in the sequence. -absl::optional MatchesArCrsPattern( +absl::optional ArCrsCombiner::MatchesArCrsPattern( HloInstruction* instruction) { auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool { if (instruction->user_count() != 1) { @@ -78,23 +75,23 @@ absl::optional MatchesArCrsPattern( return absl::nullopt; } auto next = instruction->users()[0]; + int64 distance = 1; while (!next->IsCrossReplicaAllReduce()) { if (can_ar_move_past_instruction(next)) { next = next->users()[0]; } else { return absl::nullopt; } + ++distance; } if (!Cast(next)->IsNoop() && computation_is_addition(next->called_computations()[0])) { - return absl::optional(next); + return absl::optional(ArCrsPair(instruction, next, distance)); } else { return absl::nullopt; } } -} // namespace - absl::optional ArCrsCombiner::WhileFromBodyParameter( HloInstruction* instruction) { CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); @@ -236,15 +233,55 @@ bool ArCrsCombiner::InstructionsComputeSameValue( } void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + // Say that two or more ARs lead to the same CRS: (AR1, CRS), (AR2, CRS), + // ... , (ARn, CRS). + // If as we traverse the HLO graph we start tracking the pair (AR2, CRS), + // and later find that AR1's distance from the CRS is longer, we discard + // AR2 and start tracking AR1. We put the discarded ids in this set, in order + // to skip processing of short paths when we encounter the other ARs that + // have the same id as AR2. + absl::flat_hash_set discarded_ar_ids; for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { - auto maybe_crs = MatchesArCrsPattern(instruction); - if (maybe_crs) { - auto crs = *maybe_crs; + auto maybe_pair = MatchesArCrsPattern(instruction); + if (maybe_pair) { + auto pair = *maybe_pair; int64 ar_id = *(instruction->all_reduce_id()); - if (crs_reserved_map_.find(crs) == crs_reserved_map_.end()) { - all_reduce_map_[ar_id].push_back(instruction); - crs_reserved_map_[crs] = ar_id; + if (discarded_ar_ids.find(ar_id) != discarded_ar_ids.end()) { + continue; + } + auto it = crs_reserved_map_.find(pair.crs); + if (it != crs_reserved_map_.end()) { + auto prev_ar_id = it->second; + // Since there is another AR paired with CRS, + // all_reduce_map_[prev_ar_id] should exist, but + // all_reduce_map_[ar_id] shouldn't. + CHECK(all_reduce_map_.find(ar_id) == all_reduce_map_.end()); + CHECK_NE(prev_ar_id, ar_id); + auto prev_pair = all_reduce_map_[prev_ar_id].back(); + int64 prev_distance = prev_pair.distance; + if (prev_distance < pair.distance) { + // The current AR's distance to CRS is longer than the previously + // tracked AR, so we discard the previous AR. + all_reduce_map_.erase(prev_ar_id); + discarded_ar_ids.insert(prev_ar_id); + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; + } else { + // Discard the current AR id because we are keeping the previously + // tracked AR. + discarded_ar_ids.insert(ar_id); + } + } else { + if (all_reduce_map_.find(ar_id) != all_reduce_map_.end()) { + int64 prev_distance = all_reduce_map_[ar_id].back().distance; + CHECK_EQ(prev_distance, pair.distance) + << "All ARs with the same AR ID must have the same distance " + "from the corresponding CRSs. Found: " + << prev_distance << " and " << pair.distance; + } + all_reduce_map_[ar_id].push_back(pair); + crs_reserved_map_[pair.crs] = ar_id; } } } @@ -254,11 +291,11 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) { void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { for (auto it : all_reduce_map_) { auto all_reduce_id = it.first; - auto instruction_vec = it.second; - CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); - auto instr_0 = instruction_vec[0]; - for (int i = 1; i < instruction_vec.size(); ++i) { - auto instr_i = instruction_vec[i]; + auto pairs_vec = it.second; + CHECK_EQ(pairs_vec.size(), num_spatial_partitions_); + auto instr_0 = pairs_vec[0].ar; + for (int i = 1; i < pairs_vec.size(); ++i) { + auto instr_i = pairs_vec[i].ar; auto next_0 = instr_0->users()[0]; auto next_i = instr_i->users()[0]; absl::flat_hash_map visited_pairs; @@ -282,8 +319,9 @@ StatusOr ArCrsCombiner::RewriteGraph() { return false; } for (auto it : all_reduce_map_) { - auto instruction_vec = it.second; - for (auto all_reduce : instruction_vec) { + auto pairs_vec = it.second; + for (auto pair : pairs_vec) { + auto all_reduce = pair.ar; auto parent_computation = all_reduce->parent(); auto all_reduce_id = all_reduce->all_reduce_id(); auto prev = all_reduce->mutable_operand(0); @@ -304,16 +342,23 @@ StatusOr ArCrsCombiner::RewriteGraph() { ? next->operands()[1] : next->operands()[0]; // To move the AR past the addition/subtraction, we need to divide - // other_operand by the number of spatial partitions. - auto shape = other_operand->shape(); - Literal lit(shape); - lit.PopulateWithValue(num_spatial_partitions_); - auto divisor = parent_computation->AddInstruction( - HloInstruction::CreateConstant(lit.Clone())); - auto division = - parent_computation->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDivide, other_operand, divisor)); - TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + // other_operand by the number of spatial partitions, except if + // other_operand is a cross-module AR, which can be eliminated. + if (other_operand->IsCrossModuleAllReduce() && + other_operand->user_count() == 1) { + TF_CHECK_OK(other_operand->ReplaceAllUsesWith( + other_operand->mutable_operand(0))); + } else { + auto shape = other_operand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = parent_computation->AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDivide, + other_operand, divisor)); + TF_CHECK_OK(other_operand->ReplaceUseWith(next, division)); + } break; } default: diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h index e61ef5d4f9072979a6c356a9456c91e19405b01e..f503e1d5f2b519687e40818a61f0c0be9dfd3ab0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.h +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -26,11 +26,47 @@ limitations under the License. namespace xla { // When the HLO graph contains a cross-module AllReduce, followed by some simple -// linear operations, followed by a cross-replica AllReduce, we can combine the -// CMAR and the CRAR, to use an efficient AllReduce implementation that fully -// utilizes the interconnect bandwidth. +// linear operations, followed by a cross-replica AllReduce (also known as +// cross-replica sum, or CRS), we can combine the CMAR and the CRAR, to use an +// efficient AllReduce implementation that fully utilizes the interconnect +// bandwidth. // Such sequences appear in spatially partitioned models. -// This pass must run right after spatial partitioning. +// This pass must run right after spatial partitioning, when the code is still +// in a single HLO module. +// +// The steps are: +// 1) Find CMARs followed by simple ops followed by CRARs. +// 2) Group CMARs by all_reduce_id. They must all be rewritten. +// 3) Prove that the CMAR patterns in each core produce the same result. +// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the +// other operand by the number of spatial partitions. +// 5) Turn the CRAR into an all-core AllReduce. +// +// The pass also handles the case where multiple CMARs lead to the same CRAR, +// and eliminates all CMARs. This graph: +// +// Y +// | +// X CMAR_2 Z +// | \ / +// CMAR_1 + +// \ / +// + +// | +// CRAR +// +// gets rewritten to: +// +// Z num_partitions +// \ / +// Y div +// \ / +// X + +// \ / +// + +// | +// all-core AR +// class ArCrsCombiner : public HloModulePass { public: ArCrsCombiner(int num_spatial_partitions) @@ -43,6 +79,28 @@ class ArCrsCombiner : public HloModulePass { HloInstruction* i2); private: + // We used this struct because multiple ARs could be paired with the same CRS. + // In this case, we want to select the AR that is furthest from the CRS, + // because it makes it easier to eliminate all ARs during RewriteGraph. + struct ArCrsPair { + HloInstruction* ar; + HloInstruction* crs; + // The length of the path from AR to CRS in the HLO graph. + int64 distance; + + ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, + int64 dist) + : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} + + string ToString() { + return absl::StrCat("(AR: ", ar->name(), ", CRS: ", crs->name(), + ", distance: ", distance, ")"); + } + }; + + absl::optional MatchesArCrsPattern( + HloInstruction* instruction); + // If the passed instruction is a while parameter, and the while body is only // called by a single while instruction, return the while instruction. absl::optional WhileFromBodyParameter( @@ -80,8 +138,8 @@ class ArCrsCombiner : public HloModulePass { int num_spatial_partitions_; - // Map from all-reduce ids to the all reduce instructions. - absl::flat_hash_map> all_reduce_map_; + // Map from all-reduce ids to the AR/CRS pairs. + absl::flat_hash_map> all_reduce_map_; // Map from a CRS instruction to the all-reduce ID of the AR paired with the // CRS. Sometimes, several ARs in the code could be paired with the same CRS. diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc index 5152f0dc884a153f9b0ade06acd479832d87ff25..9c9db74fd2fdab836f91d2f749d08ad93f8879b0 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -1005,11 +1005,11 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { op::Tuple(op::AllReduce(op::Add( op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())), - op::Divide(op::AllReduce(), op::Constant()))), + op::Parameter())), op::AllReduce(op::Add( op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())), - op::Divide(op::AllReduce(), op::Constant()))))); + op::Parameter())))); auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); @@ -1093,15 +1093,17 @@ ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) { ArCrsCombiner combiner(2); auto changed = combiner.Run(module.get()).ValueOrDie(); EXPECT_TRUE(changed); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::AllReduce(op::Add( - op::Parameter(), - op::Divide(op::Add(op::AllReduce(), op::Constant()), - op::Constant()))), - op::AllReduce(op::Add( - op::Parameter(), - op::Divide(op::Add(op::AllReduce(), op::Constant()), - op::Constant()))))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))), + op::AllReduce(op::Add( + op::Parameter(), + op::Add(op::Parameter(), + op::Divide(op::Constant(), op::Constant())))))); + auto crs_after = module->entry_computation()->root_instruction()->operands()[0]; auto replica_groups_after = crs_after->replica_groups(); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 215e8ced4bb3f98a26ac4eb9912a7fd4d917852f..d016d3e03d5e994841b81cda6214b6ff7cb550be 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/byte_order.h" @@ -67,18 +66,38 @@ const absl::optional>& BackendOptions::allowed_devices() const { return allowed_devices_; } +namespace { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(tensorflow::thread::ThreadPool* pool) + : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + tensorflow::thread::ThreadPool* pool_ = nullptr; +}; + +} // namespace + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. -struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper(const int num_threads) +struct Backend::IntraOpThreadPool { + explicit IntraOpThreadPool(const int num_threads) : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), "XLAEigen", num_threads)), - wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + wrapper(new EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} std::unique_ptr pool; - std::unique_ptr wrapper; + std::unique_ptr wrapper; std::unique_ptr device; }; @@ -146,8 +165,7 @@ Backend::Backend(se::Platform* platform, Compiler* compiler, const int num_threads = intra_op_parallelism_threads > 0 ? intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); - intra_op_thread_pool_wrapper_.reset( - new EigenThreadPoolWrapper(num_threads)); + intra_op_thread_pool_.reset(new IntraOpThreadPool(num_threads)); } } @@ -159,17 +177,17 @@ int Backend::default_device_ordinal() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->device.get(); + return intra_op_thread_pool_->device.get(); } tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { - if (intra_op_thread_pool_wrapper_ == nullptr) { + if (intra_op_thread_pool_ == nullptr) { return nullptr; } - return intra_op_thread_pool_wrapper_->pool.get(); + return intra_op_thread_pool_->pool.get(); } StatusOr Backend::stream_executor( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index c35f033dc0180409ae3888c2050021da83f5c72a..e7f29a044b95015aa7e547373c24971646833280 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -156,7 +156,6 @@ class Backend { Status ResetDevices(); private: - struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, absl::Span stream_executors, TransferManager* transfer_manager, @@ -183,7 +182,8 @@ class Backend { std::unique_ptr memory_allocator_; // For the CPU backend, an Eigen threadpool device for use by Eigen code. - std::unique_ptr intra_op_thread_pool_wrapper_; + struct IntraOpThreadPool; + std::unique_ptr intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index eda026ac5685dc469a6230094eb28b3618e36400..dbabd82dd55465dd4c85a56aea849a3e3702d6bf 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -28,6 +28,13 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( *rhs = batch_dot->mutable_operand(1); const Shape& lhs_shape = lhs->shape(); + // A dot with no contracting dims will be rewritten into a multiply by + // AlgebraicSimplifier. Dots with multiple contracting dims are currently + // unsupported. + if (dim_numbers.lhs_contracting_dimensions_size() != 1) { + return false; + } + std::vector degenerate_dims; for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { if (lhs_shape.dimensions(batch_dim) == 1) { diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 52ec1a794c5e9f4452a4bf2b648f453d8acfe976..a81f394a38f091b89b7f1e4d26653ff549f35b75 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -169,5 +169,47 @@ main { /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); } +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsNonContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,101] parameter(0) + b = f32[1,101] parameter(1) + ROOT dot = f32[1,101,101] dot(a,b), lhs_batch_dims={0}, + lhs_contracting_dims={}, + rhs_batch_dims={0}, + rhs_contracting_dims={} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDimsMultipleContracting) { + const char* hlo_text = R"( +HloModule BatchDot + +main { + lhs = f32[1,5,17,10,13] parameter(0) + rhs = f32[1,9,10,13,6,5] parameter(1) + ROOT dot = f32[10,1,17,9,6] dot(lhs,rhs), lhs_batch_dims={3,0}, + rhs_batch_dims={2,0}, + lhs_contracting_dims={1,4}, + rhs_contracting_dims={5,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + BatchDotSimplification pass; + ASSERT_FALSE(pass.Run(m.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 551ac4be73a7630d213a53ca3606aa7f890cd794..2caa979745b3b40817acb1b6951e1de5ffa294a4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -282,8 +283,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction* value = builder.AddInstruction( HloInstruction::CreateParameter(1, s32_shape, "value")); - HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); @@ -308,8 +312,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { HloInstruction* value = builder.AddInstruction( HloInstruction::CreateParameter(1, bf16_shape, "value")); - HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}), + {key, value}, 0, /*is_stable=*/false, &builder, + module.get())); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index e1b91b500191c7756f3d1a4b160a0dd1e09cfe7d..cbebbdc8a2d7d0b65f12accbe424bea383ff5355 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -191,6 +191,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. worklist.push_back(std::make_pair(subcomputation, diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 94af788c54f6c722997311bec50da3ed93aa3cee..98304757cae91d22466ed25f8c6e36ce90a848db 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -64,6 +64,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: return CallContext::kParallel; default: diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index c02ffda575278905f6549b362e5e7d94f5713b36..57a636fd740995d6cce933fe19d5592a64bde5cf 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -30,7 +30,7 @@ namespace xla { // The context in which a computation is called by another computation. enum class CallContext { - // In a parallel contex the computation is applied to each element of the + // In a parallel context the computation is applied to each element of the // array argument(s). kMap and kReduce instructions call computations in // parallel context. kParallel, diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d4535b204d7f3ad8d4e24beea5d0dd79e7a15ab0..42672bc3875af2d732d80691df6bf85b9d8080cd 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -136,6 +136,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl index e78330b21689fdd818cd97128bbcaaa9e0118602..ffa1cd4ec8e26e7dbe92e7b99cf65e99db5400b9 100644 --- a/tensorflow/compiler/xla/service/cpu/build_defs.bzl +++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl @@ -1,12 +1,11 @@ """build_defs for service/cpu.""" - def runtime_copts(): - """Returns copts used for CPU runtime libraries.""" - return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ - "//tensorflow:android_arm": ["-mfpu=neon"], - "//conditions:default": [] - }) + select({ - "//tensorflow:android": ["-O2"], - "//conditions:default": [] - })) + """Returns copts used for CPU runtime libraries.""" + return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ + "//tensorflow:android_arm": ["-mfpu=neon"], + "//conditions:default": [], + }) + select({ + "//tensorflow:android": ["-O2"], + "//conditions:default": [], + })) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index eafda68510d93ee54f2aead60a84f3e97b3fe1f4..19ab3bddb567afeeddb7c01b9a847b51bea5d957 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -95,6 +95,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" @@ -105,6 +106,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/dynamic_annotations.h" namespace xla { namespace cpu { @@ -255,6 +257,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); + pipeline.AddPass(); + // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); @@ -633,7 +637,13 @@ StatusOr> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features); + &target_machine_features, +#ifdef MEMORY_SANITIZER + /*emit_code_for_msan=*/true +#else + /*emit_code_for_msan=*/false +#endif + ); TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); @@ -670,9 +680,9 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); @@ -831,7 +841,9 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features); + &target_machine_features, + // TODO(b/66051036): Run full msan for AOT. + /*emit_code_for_msan=*/false); TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7fbe0fa157c57eb0c274662a1de95cf5328ccfa8..4ac61f44d9f38425da2d1fc6b9495cb4deba5047 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 3361a5973f5e8c91802b26d68477347b196d3cac..fae9670051a654f38f09856368ffb700b0c7a085 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 0fecbaf391bc3122646af30b508fc1a88b6641e9..2bf22ec6e43ea9944935a4d0d5dcd22c5d190c17 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -963,8 +963,8 @@ Status EmitBatchDotOperation( KernelSupportLibrary ksl(b); return ksl.ForWithStatus( - "bdot", /*start=*/0, /*end=*/batch_count, /*step=*/1, - [&](llvm::Value* indvar) { + llvm_ir::IrName(&dot, "bdot"), /*start=*/0, /*end=*/batch_count, + /*step=*/1, [&](llvm::Value* indvar) { DotDimensionNumbers adjusted_dim_numbers = dot.dot_dimension_numbers(); adjusted_dim_numbers.clear_lhs_batch_dimensions(); adjusted_dim_numbers.clear_rhs_batch_dimensions(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index efdda8599a1a66a0b2e43d17cfb35e3514e905b0..2418d96440f9994842a54769cf6d561610ccfa18 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -77,7 +77,6 @@ namespace { using llvm_ir::AsStringRef; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -namespace gtl = tensorflow::gtl; } // namespace namespace cpu { @@ -87,7 +86,8 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features) + const TargetMachineFeatures* target_machine_features, + bool emit_code_for_msan) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -97,7 +97,8 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features) { + target_machine_features_(*target_machine_features), + emit_code_for_msan_(emit_code_for_msan) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_cpu_enable_fast_math())); @@ -517,6 +518,7 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { case U8: case S16: case U16: + case BF16: case F16: case S32: case U32: @@ -577,72 +579,14 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - llvm::FunctionType* less_than_type = llvm::FunctionType::get( - b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()}, - /*isVarArg=*/false); - auto less_than_function = llvm_ir::CreateFunction( - less_than_type, llvm::GlobalValue::InternalLinkage, - /*enable_fast_math=*/false, - /*optimize_for_size=*/true, absl::StrCat(IrName(sort), "_comparator"), - module_); - // Emit the code for the less_than function. - { - llvm::IRBuilder<>::InsertPointGuard guard(b_); - - auto* entry_bb = - llvm::BasicBlock::Create(b_.getContext(), "entry", less_than_function); - - b_.SetInsertPoint(entry_bb); - auto keys_ir_type = llvm_ir::PrimitiveTypeToIrType(keys_type, module_); - CHECK_EQ(less_than_function->arg_size(), 2); - llvm::Value* keys_lhs_ptr = less_than_function->arg_begin(); - keys_lhs_ptr = PointerCast(keys_lhs_ptr, keys_ir_type->getPointerTo()); - llvm::Value* keys_rhs_ptr = less_than_function->arg_begin() + 1; - keys_rhs_ptr = PointerCast(keys_rhs_ptr, keys_ir_type->getPointerTo()); - - // TODO(b/122298745): Replace the custom compare logic with a call to the - // computation specified for the Sort op. - llvm::Value* keys_lhs = Load(keys_ir_type, keys_lhs_ptr); - llvm::Value* keys_rhs = Load(keys_ir_type, keys_rhs_ptr); - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(keys_type)) { - // We would like a total order of floating point numbers so that the - // sort has a predictable behavior in the presence of NaNs. Rather - // than using floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the - // obvious order, -0 is ordered before 0, and -NaN and NaN appear at - // the beginning and end of the ordering. - auto k = b_.getInt(llvm::APInt::getSignedMaxValue( - keys_lhs->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b_.CreateSelect(b_.CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b_.CreateSub(k, v), v); - }; - keys_lhs = b_.CreateBitCast(keys_lhs, comparison_type); - keys_rhs = b_.CreateBitCast(keys_rhs, comparison_type); - keys_lhs = maybe_flip(keys_lhs); - keys_rhs = maybe_flip(keys_rhs); - } else if (!primitive_util::IsSignedIntegralType(keys_type)) { - is_signed_comparison = false; - } - llvm::Value* result = - b_.CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - keys_lhs, keys_rhs); - llvm::ReturnInst::Create(b_.getContext(), - /*retVal=*/result, entry_bb); - } - + auto less_than_function = FindOrDie(emitted_functions_, sort->to_apply()); + CHECK(absl::c_binary_search(thread_local_computations_, sort->to_apply())); llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( b_.getVoidTy(), {b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), - b_.getInt32Ty()->getPointerTo(), less_than_function->getType()}, + b_.getInt32Ty()->getPointerTo(), b_.getInt1Ty(), b_.getInt8PtrTy(), + b_.getInt64Ty()->getPointerTo(), less_than_function->getType()}, /*isVarArg=*/false); auto* key_value_sort_func = llvm::dyn_cast( module_ @@ -673,7 +617,9 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { Call(key_value_sort_func, {b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), b_.getInt64(lower_dimensions), values, - b_.getInt32(sort->operand_count()), sizes, less_than_function}); + b_.getInt32(sort->operand_count()), sizes, + b_.getInt1(sort->is_stable()), GetExecutableRunOptionsArgument(), + GetProfileCountersArgument(), less_than_function}); if (sort->values_count() > 0) { llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, @@ -1914,7 +1860,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support variadic reduce. + // TODO(b/118333695): Support variadic reduce. if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on CPU"); } @@ -2281,6 +2227,25 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { InBoundsGEP(operands_alloca, {b_.getInt64(i)}); Store(operand_as_i8ptr, slot_in_operands_alloca); } + if (emit_code_for_msan_) { + // Mark the alloca as initialized for msan. The buffer gets read by the + // custom callee, which might be msan-instrumented. + // TODO(b/66051036): Run the msan instrumentation pass instead. + const llvm::DataLayout& dl = module_->getDataLayout(); + llvm::Type* intptr_type = b_.getIntPtrTy(dl); + auto* msan_unpoison_ir_function = llvm::cast( + module_ + ->getOrInsertFunction( + "__msan_unpoison", + llvm::FunctionType::get( + /*Result=*/b_.getVoidTy(), + /*Params=*/{i8_ptr_type, intptr_type}, /*isVarArg=*/false)) + .getCallee()); + Call(msan_unpoison_ir_function, + {PointerCast(operands_alloca, i8_ptr_type), + llvm::ConstantInt::get( + intptr_type, *operands_alloca->getAllocationSizeInBits(dl) / 8)}); + } auto* custom_call_ir_function = llvm::dyn_cast( module_ ->getOrInsertFunction( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 974dd7cd3f2254bfbc86fffae02c06c481af8902..0e372335f3aae919f9a9c559f86d4d61ab799b70 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -72,13 +72,15 @@ class IrEmitter : public DfsHloVisitorWithDefault, // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. + // emit_code_for_msan: whether emitted code should be compatible with msan. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine); + const TargetMachineFeatures* target_machine, + bool emit_code_for_msan); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -574,6 +576,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector thread_local_computations_; std::vector global_computations_; + bool emit_code_for_msan_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index a0667d0d9d1cde246f4b74626859955beeec08b0..70a6d0af02c0c2db7208db561cf29e35a74707b2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -32,8 +32,9 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes, - bool (*less_than)(char*, char*)) { + int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, int64* prof_counters, + void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); @@ -54,6 +55,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 sort_dimension_offset = c; std::unique_ptr indices(new int64[sort_dimension_elements]); + std::unique_ptr comparison_values(new char*[2 * values_count]); std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); @@ -67,16 +69,27 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - std::stable_sort( - indices.get(), indices.get() + sort_dimension_elements, - [&](int64 a, int64 b) { - int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * - values_primitive_type_size_in_bytes[0]; - return less_than(values[0] + memory_index_lhs, - values[0] + memory_index_rhs); - }); + auto compare_function = [&](int64 a, int64 b) -> bool { + int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + values_primitive_type_size_in_bytes[0]; + for (int32 i = 0; i < values_count; ++i) { + comparison_values[i * 2] = values[i] + memory_index_lhs; + comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; + } + char result = 0; // Overwritten by less_than. + less_than(&result, run_options, comparison_values.get(), nullptr, + prof_counters); + return result != 0u; + }; + if (is_stable) { + std::stable_sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } else { + std::sort(indices.get(), indices.get() + sort_dimension_elements, + compare_function); + } // Reorder the values according to the order defined by 'indices'. for (int32 idx = 0; idx < values_count; ++idx) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 5460af3485b94aaef1a5822a79e4fa325bcb67ea..50c2911c3bd392b6df12717c34d250ce86ad26e0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -22,18 +22,25 @@ limitations under the License. extern "C" { // Each entry in 'values' represents a 3-dimensional shape with dimensions -// [a, b, c]. The 'b' dimension of the first shape is sorted into ascending -// order according to the results of comparisons using the provided 'less_than' +// [a, b, c]. The 'b' dimension of each shape is sorted into ascending order +// according to the results of comparisons using the provided 'less_than' // function. 'values_count' must be > 0 and specifies the number of entries in // 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive // type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' -// bytes. The elements in each 'values' shape are reordered in the same way -// according to the comparisons using the first shape. +// bytes. 'is_stable' specifies whether the sorting should be stable. +// 'run_options' and 'prof_counters' are passed through to the less-than +// function, which expects the following arguments: +// - pointer to the return value buffer (char*) +// - xla::ExecutableRunOptions = 'run_options' (char*) +// - pointers to the parameter buffers (char**) +// - pointers to the buffer tables = nullptr for thread local functions (char**) +// - profile counters = 'prof_counters' (int64*) extern void __xla_cpu_runtime_KeyValueSort( tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, - tensorflow::int32* values_primitive_type_size_in_bytes, - bool (*less_than)(char*, char*)); + tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, tensorflow::int64* prof_counters, + void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 9c2685674fbc133de1220caef81ac3b60a1c0f7c..f7b64738b7b314b56f4ae60336d9c85c90287219 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -337,6 +337,11 @@ bool RegisterKnownJITSymbols() { reinterpret_cast(memset_pattern16)); #endif +#ifdef MEMORY_SANITIZER + registry->Register("__msan_unpoison", + reinterpret_cast(__msan_unpoison)); +#endif + return true; } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc index 3934c03a04c978009282b3cd0d39bacf9b12a356..762ee67db9a1b2a753c6ec5538dee1d13282942e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -26,10 +26,16 @@ TEST_F(CpuKeyValueSortTest, SortR1) { const string hlo_text = R"( HloModule KeyValueSort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY main { a = f32[10] parameter(0) - ROOT result = f32[10] sort(f32[10] a), dimensions={0} + ROOT result = f32[10] sort(f32[10] a), dimensions={0}, to_apply=compare } )"; diff --git a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc index eb6c44b70ab34d0a294880b5de4fe0b3ba5e19e5..9fc472ff767441e60cf618ac9022e5c50ea20023 100644 --- a/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.cc @@ -938,6 +938,53 @@ void TiledSmallGemmEmitter::EmitTiledGemm( }); } +llvm::Type* GetPointerToElementType(llvm::Type* pointer_type) { + llvm::Type* type = + llvm::cast(pointer_type)->getElementType(); + while (auto* array_type = llvm::dyn_cast(type)) { + type = array_type->getElementType(); + } + + return type->getPointerTo(); +} + +struct GemvBuffersWithCanonicalType { + llvm::Value* lhs_canonicalized; + llvm::Value* rhs_canonicalized; + llvm::Value* addend_canonicalized; + llvm::Value* result_canonicalized; +}; + +GemvBuffersWithCanonicalType GetGemvBuffersWithCanonicalType( + llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* b) { + // We characterize a GEMV operation via M and K, since N is implicitly 1. + // This means the GEMV that multiplies (say) [5,6] with [6,1] is implemented + // by the same GEMV that multiplies [5,6] with [1,6]. However, the + // `llvm::Types` for the inputs to the two GEMVs don't match (in a trivial + // sense -- the in memory representations are the same) since they're computed + // from the `xla::Shape`s. Since we want to be able to call the same + // `llvm::Function` for the two GEMVs we canonicalize the types of the GEMV + // inputs here into the same type. + GemvBuffersWithCanonicalType buffers_with_canonical_type; + llvm::Type* lhs_type = lhs->getType(); + llvm::Type* rhs_type = rhs->getType(); + llvm::Type* addend_type = addend ? addend->getType() : nullptr; + llvm::Type* result_type = result->getType(); + + buffers_with_canonical_type.lhs_canonicalized = + b->CreateBitCast(lhs, GetPointerToElementType(lhs_type)); + buffers_with_canonical_type.rhs_canonicalized = + b->CreateBitCast(rhs, GetPointerToElementType(rhs_type)); + buffers_with_canonical_type.addend_canonicalized = + addend ? b->CreateBitCast(addend, GetPointerToElementType(addend_type)) + : nullptr; + buffers_with_canonical_type.result_canonicalized = + b->CreateBitCast(result, GetPointerToElementType(result_type)); + + return buffers_with_canonical_type; +} + } // namespace void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, @@ -950,12 +997,18 @@ void EmitRowMajorGemv(PrimitiveType scalar_type, int64 tile_rows, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); + KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, addend, result, - [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result) { + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { RowMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, result, b); emitter.Emit(); @@ -972,12 +1025,18 @@ void EmitColumnMajorGemv(PrimitiveType scalar_type, int64 tile_rows, /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, /*k=*/k, /*has_addend=*/addend != nullptr); + GemvBuffersWithCanonicalType canonical_inputs = + GetGemvBuffersWithCanonicalType(lhs, rhs, addend, result, b); + KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), lhs, - rhs, addend, result, - [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result) { + /*optimize_for_size=*/optimize_for_size, b, config.GetCacheKey(), + canonical_inputs.lhs_canonicalized, canonical_inputs.rhs_canonicalized, + canonical_inputs.addend_canonicalized, + canonical_inputs.result_canonicalized, + [&config, b, &canonical_inputs](llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* addend, + llvm::Value* result) { ColumnMajorMatrixVectorProductEmitter emitter(config, lhs, rhs, addend, result, b); emitter.Emit(); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index c54e954c222ff7ca9c0739ec8a55b9d79b74a437..af039776b7f157a407f8f4cf3d7cabdbee45b014 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -105,6 +105,7 @@ class DfsHloVisitorBase { } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; + virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 33b2cc3fb098ec0d92f68756526fcc4a761d7149..341bb37b8355e9987a0331d0a66bb8fe87f019cf 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -91,6 +91,9 @@ class DfsHloVisitorWithDefaultBase Status HandleFft(HloInstructionPtr fft) override { return DefaultAction(fft); } + Status HandleTriangularSolve(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index e8bc6d05716a2ef02e0280e86c7df4ac22fe78c4..559b9c1f2c9f341293ca89adc61e3312fd9f313c 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -297,7 +297,7 @@ StatusOr DotDecomposer::Run(HloModule* module) { const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); // A dot it not canonical if there are more than one contracting // dimension. - if (dnums.lhs_contracting_dimensions_size() > 1) { + if (dnums.lhs_contracting_dimensions_size() != 1) { non_canonical_dots.push_back(instruction); continue; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 727e0bfa52d45b6f8c67d7d04613e4865f18a53c..808929be75ec6fd0cfb15418a231431b8d51e089 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -440,14 +440,16 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( {operand_value}, {operand_value->getType()}, b_); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = FCmpOEQ(operand_value, zero); - auto olt = FCmpOLT(operand_value, zero); - return Select(oeq, zero, - Select(olt, llvm::ConstantFP::get(type, -1.0), - llvm::ConstantFP::get(type, 1.0))); + auto ne0_i1 = FCmpONE(operand_value, zero); + auto ne0_float = UIToFP(ne0_i1, type); + llvm::Value* result = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {ne0_float, operand_value}, + {operand_value->getType()}, b_); + auto is_nan = FCmpUNO(operand_value, operand_value); + result = Select(is_nan, operand_value, result); + return result; } case HloOpcode::kIsFinite: { // abs(x) o!= inf, this works because the comparison returns false if @@ -855,6 +857,9 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } +// TODO(b/123355973): We have an implementation of erfinv in math.cc. We +// shouldn't have two implementations, especially since this one isn't testable +// (it's only observable via a normally-distributed RNG). StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) { if (prim_type != F16 && prim_type != F32 && prim_type != F64) { @@ -1362,26 +1367,69 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); llvm::Type* raw_value_ty = raw_value->getType(); - // Convert raw integer to float in range [0, 1) if the element is a float. + // If we're generating a floating-point value, convert the raw integer R (i.e. + // `raw_value`) to a float in the range [0, 1). + // + // The basic approach is to choose a significand and exponent such that the + // significand is uniformly distributed and the exponent is distributed, well, + // exponentially (it's more likely to be close to 0 than far from 0). + // + // An easy way to do this is to say that the significand is the first S bits + // of R, and the exponent is determined by the number of trailing zeroes in R, + // exp = 2^-(cttz(R) + 1). (+1 because the largest exponent should be -1; + // this way the largest value we can return is 1.999... * 2^-1 = 1-ε.) + // + // This results in a small bias. Namely, if R has enough trailing zeroes, the + // significand and exponent will "overlap". As a concrete example, consider + // + // 20 X's 12 zeroes + // R = 0bXXXXXXXXXXXXXXXXXXXX000000000000 + // + // Here the exponent is 2^-13 because R has 12 trailing zeroes. The + // significand is made up of the first 23 most-significant bits of R, which we + // observe contain 3 zeroes. This is biased because any random value with + // exponent 2^-12 will have a significand which ends in `000`. + // + // For f32s, this problem occurs only when there are more than 32-23 = 9 + // trailing zeros, which happens with probability 0.5^10 = ~0.1%. Moreover the + // probability of a large bias (i.e. many trailing 0s in the significand) is + // exponentially low. So we deem this acceptable. llvm::Value* elem_value = raw_value; if (elem_ir_ty->isFloatingPointTy()) { - unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); - CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); - // Perform the division using the float type with the same number of bits - // as the raw value to avoid overflow. - if (raw_value_size_in_bits == 32) { - elem_value = UIToFP(elem_value, b_->getFloatTy()); - elem_value = FDiv(elem_value, - llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); - } else { - elem_value = UIToFP(elem_value, b_->getDoubleTy()); - elem_value = FDiv( - elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); - } - - if (elem_ir_ty != elem_value->getType()) { - elem_value = FPTrunc(elem_value, elem_ir_ty); - } + const auto& dest_flt_semantics = elem_ir_ty->getFltSemantics(); + const int bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK_GE(bits, llvm::APFloat::semanticsSizeInBits(dest_flt_semantics)); + + // Subtract 1 because semanticsPrecision includes the "hidden bit", i.e. the + // implicit "1." at the beginning of the significand. + const int significand_bits = + llvm::APFloat::semanticsPrecision(dest_flt_semantics) - 1; + + llvm::Value* cttz = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cttz, {raw_value, /*is_zero_undef=*/b_->getFalse()}, + {raw_value->getType()}, b_); + llvm::Value* significand = LShr(raw_value, bits - significand_bits); + + // Exponent bias is -127 for f32, meaning that if the exponent is E and the + // significand is S, then the value of the number is 2^(E - 127) * (1.S). + // + // We want cttz == 0 to correspond to 2^-1, so our exponent is computed as + // E = 126 - cttz. + // + // For f64, this is all the same, except the bias is -1023. + // + // In IEEE floating point, the absolute value of the exponent bias equals + // the value of the largest possible exponent. + const int bias = -llvm::APFloat::semanticsMaxExponent(dest_flt_semantics); + llvm::Value* exponent = + Sub(llvm::ConstantInt::get(cttz->getType(), -bias - 1), cttz); + + // Now just slot everything into place! The `Trunc` is here because + // raw_value may be larger than our float destination. + elem_value = + BitCast(Trunc(Or(Shl(exponent, significand_bits), significand), + b_->getIntNTy(elem_ir_ty->getPrimitiveSizeInBits())), + elem_ir_ty); } // Convert the value for the requested distribution. @@ -1767,18 +1815,10 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( auto index_typed_const = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; - // TODO(b/118437727): Remove the R1 path. - llvm::Value* start_index_value; - if (hlo->operand(1)->shape().rank() == 1) { - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); - } else { - llvm_ir::IrArray::Index zero_index(index_type); - TF_ASSIGN_OR_RETURN( - start_index_value, - operand_to_generator.at(hlo->operand(1 + i))(zero_index)); - } + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1 + i))(zero_index)); // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) @@ -1924,18 +1964,10 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( return llvm::ConstantInt::get(index_type, c); }; - llvm::Value* start_index_value; - // TODO(b/118437727): Remove the R1 path. - if (hlo->operand(2)->shape().rank() == 1) { - llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); - TF_ASSIGN_OR_RETURN(start_index_value, - operand_to_generator.at(hlo->operand(2))(dim_index)); - } else { - llvm_ir::IrArray::Index zero_index(index_type); - TF_ASSIGN_OR_RETURN( - start_index_value, - operand_to_generator.at(hlo->operand(2 + i))(zero_index)); - } + llvm_ir::IrArray::Index zero_index(index_type); + TF_ASSIGN_OR_RETURN( + llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(2 + i))(zero_index)); // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d3e2acaabd4f602171def70ccd3d4fd5adce0d0d..7d360fe38cfeda17878c363253c41883ec9fd64f 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -216,8 +216,11 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator); + // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. + // + // Precondition: raw_value has at least as many bits as hlo's element type. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 10b8c01ff1383658fcfb2271c177ba54347f985a..1518d83083b3b0ce876da9344c483a23cd5b073c 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" - namespace xla { StatusOr> Executable::ExecuteOnStreams( @@ -173,11 +172,13 @@ Status Executable::DumpHloSnapshot() { } filename = SanitizeFileName(std::move(filename)); string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(hlo_session, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); + const size_t size = hlo_session.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK(tensorflow::SerializeToBufferDeterministic( + hlo_session, serialized.get(), size)); + return tensorflow::WriteStringToFile( + tensorflow::Env::Default(), file_path, + absl::string_view(serialized.get(), size)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 590942cddcdd138981ee829f090ae17b0d038e1a..a58ac39dffad56315308f784b08e6b6087b8e30a 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -296,7 +296,7 @@ static StatusOr PermuteBatchAndOffsetDims( // [3,1] out of operand into an accumulator of shape [4,3,1]. We then // reshape this result to [2,2,3] and finally transpose it to [2,3,2]. -StatusOr GatherExpander::ExpandGather( +StatusOr GatherExpander::ExpandInstruction( HloInstruction* gather_instr) { CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); @@ -361,25 +361,11 @@ StatusOr GatherExpander::ExpandGather( output_rank); } -StatusOr GatherExpander::Run(HloModule* module) { - auto is_nontrivial_gather = [](HloInstruction* inst) { - return inst->opcode() == HloOpcode::kGather && - // Avoid expanding gather ops that produce zero sized tensors, - // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::IsZeroElementArray(inst->shape()); - }; - - std::vector gather_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - absl::c_copy_if(computation->instructions(), - std::back_inserter(gather_instrs), is_nontrivial_gather); - } - - for (HloInstruction* inst : gather_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandGather(inst)); - TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); - } - - return !gather_instrs.empty(); +bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::IsZeroElementArray(inst->shape()); } + } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 8af9c6b71fbc391bf7c0e9809e979b65135a6df3..5625a37cb46ca5b70f69d86bc424f6512bfb293f 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -16,20 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. This lets backends that don't support gather directly to // nevertheless have a minimum level of support. -class GatherExpander : public HloModulePass { +class GatherExpander : public OpExpanderPass { public: absl::string_view name() const override { return "gather_expander"; } - StatusOr Run(HloModule* module) override; protected: - StatusOr ExpandGather(HloInstruction* gather_instr); + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* gather_inst) override; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 7d450f4b53cdea209f2ef10ba785be6ec3b8bf8d..cb43c27be961262bf29d4a3958de62cfada19aed 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 85fb2dd47abdad7073bf15a2f8b974a3ae0f01e4..25c4f70d89b4ebc483a61f1e28c7a55eb31f4bdf 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -305,6 +305,7 @@ cc_library( "sequential_thunk.cc", "thunk.cc", "thunk_schedule.cc", + "triangular_solve_thunk.cc", "tuple_thunk.cc", "while_thunk.cc", ], @@ -324,6 +325,7 @@ cc_library( "sequential_thunk.h", "thunk.h", "thunk_schedule.h", + "triangular_solve_thunk.h", "tuple_thunk.h", "while_thunk.h", ], @@ -364,6 +366,8 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "//tensorflow/stream_executor:blas", + "//tensorflow/stream_executor:device_memory", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -396,18 +400,21 @@ cc_library( srcs = ["cudnn_conv_algorithm_picker.cc"], hdrs = ["cudnn_conv_algorithm_picker.h"], deps = [ + ":autotuning_proto", ":backend_configs", ":buffer_comparator", ":cudnn_conv_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "//tensorflow/core:logger", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -758,6 +765,7 @@ cc_library( "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:sort_simplifier", + "//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -1081,3 +1089,12 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) + +xla_proto_library( + name = "autotuning_proto", + srcs = ["autotuning.proto"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/autotuning.proto b/tensorflow/compiler/xla/service/gpu/autotuning.proto new file mode 100644 index 0000000000000000000000000000000000000000..b4a08963b4f2ebc55c89ed57325093536f343bd1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/autotuning.proto @@ -0,0 +1,81 @@ +// This file defines protos that store the results of autotuning XLA:GPU +// operations. +// +// They are in proto format because we want to log them structured. They offer +// tremendous statistical, testing, and debugging value. +syntax = "proto3"; + +package xla.gpu; + +import "google/protobuf/duration.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; + +message CudnnVersion { + int32 major = 1; + int32 minor = 2; + int32 patch = 3; +} + +message ComputeCapability { + int32 major = 1; + int32 minor = 2; +} + +message AutotuneResult { + message SuccessResult { + int64 scratch_bytes = 1; + google.protobuf.Duration run_time = 2; + } + + message ConvKey { + int64 algorithm = 1; + bool tensor_ops_enabled = 2; + } + + // If the conv runs successfully, success will be populated with the + // autotuning result. Otherwise, the error message is propagated. + oneof result { + SuccessResult success = 3; + string error_string = 4; + } + + oneof key { + ConvKey conv = 5; + } + + // Sometimes we run a correctness checker during autotuning. It compares the + // result buffer content between two algorithms, say, "reference" and "test" + // algorithms. The "test" algorithm is the one associated with this + // AutotuneResult. + // + // This field records the reference algorithm used. Notice that naming it + // "reference" doesn't mean it's always correct. However, empirically it's + // more correct, as it's "algo 0", less fancy than the compared one. + // + // Notice that the checker_failure may exist even in the success case. + // This is because the error string in `result` comes from the underlying + // implementation like cuDNN, which isn't aware that it produced an incorrect + // result. And even if the checker detects an incorrect result, we can still + // retrieve scratch_bytes and runtime_ms. + oneof checker_failure { + ConvKey reference_conv = 6; + } +} + +message AutotuneLog { + message Instruction { + xla.HloInstructionProto instruction = 1; + repeated xla.ShapeProto operand_shapes = 2; + } + + oneof instr_oneof { + Instruction instr = 1; + } + + // Records all auto-tuning results per algorithm. + repeated AutotuneResult results = 3; + + CudnnVersion cudnn_version = 4; + ComputeCapability compute_capability = 5; +} diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 309b0aca64954e64509d731dce28ce9d8da4ee43..603af5a654589e0b02c762b57d70a8b7628b1d0f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -16,14 +16,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -32,7 +35,6 @@ namespace { using absl::optional; using se::DeviceMemoryBase; -using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; class ScratchAllocator : public se::ScratchAllocator { @@ -132,6 +134,31 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { return tensorflow::mutex_lock{it->second}; } +xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + xla::gpu::CudnnVersion cudnn_version; + if (auto* dnn = stream_executor->AsDnn()) { + StatusOr version_or = dnn->GetVersion(); + if (version_or.ok()) { + const auto& version = version_or.ValueOrDie(); + cudnn_version.set_major(version.major_version()); + cudnn_version.set_minor(version.minor_version()); + cudnn_version.set_patch(version.patch()); + } + } + return cudnn_version; +} + +xla::gpu::ComputeCapability GetComputeCapability( + se::StreamExecutor* stream_executor) { + xla::gpu::ComputeCapability cc; + int cc_major, cc_minor; + stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + cc.set_major(cc_major); + cc.set_minor(cc_minor); + return cc; +} + } // anonymous namespace // We could have caching here so that we don't redo this work for two identical @@ -145,8 +172,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -StatusOr -CudnnConvAlgorithmPicker::PickBestAlgorithm( +StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( const HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. @@ -233,8 +259,6 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm( &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); initialize_buffer(result_buffer); - se::dnn::ProfileResult best_result; - int64 best_result_bytes_used = 0; TF_ASSIGN_OR_RETURN(auto backend_config, instr->backend_config()); @@ -244,6 +268,7 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm( // this algorithm considered correct, though. optional first_algorithm; TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + std::vector profile_results; for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; @@ -254,73 +279,108 @@ CudnnConvAlgorithmPicker::PickBestAlgorithm( RunConvOptions options; options.profile_result = &profile_result; options.algo_override = alg; - bool launch_ok = + Status launch_status = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, options) - .ok(); - - if (launch_ok && profile_result.is_valid()) { - const bool crash_on_checking_failure = - instr->GetModule() - ->config() - .debug_options() - .xla_gpu_crash_on_verification_failures(); - if (comparator.has_value()) { - StatusOr result = comparator->CompareEqual( - se::DeviceMemory(result_buffer)); - if (!result.ok()) { - LOG(ERROR) << "Unable to compare " - << AlgorithmToString(*first_algorithm) << " against " - << AlgorithmToString(alg) << " for " << instr->ToString() - << ": " << result.status(); - CHECK(!crash_on_checking_failure); - } else if (!result.ValueOrDie()) { - LOG(ERROR) << "Results mismatch between different convolution " - "algorithms. This is likely a bug in convolution, or " - "an excessive loss of precision in convolution. " - << instr->ToString() << " for " - << AlgorithmToString(*first_algorithm) << " vs " - << AlgorithmToString(alg); - CHECK(!crash_on_checking_failure); - } - } else if (cross_check_enabled) { - auto comp = F16BufferComparator::Create( - se::DeviceMemory(result_buffer), compiler_, allocator, - &stream); - if (comp.ok()) { - comparator.emplace(comp.ConsumeValueOrDie()); - first_algorithm.emplace(alg); - } else { - LOG(ERROR) << "Fail to initialize buffer comparator: " - << comp.status() << ", instruction: " << instr->ToString(); - CHECK(!crash_on_checking_failure); - } + &scratch_allocator, &stream, options); + + profile_results.emplace_back(); + AutotuneResult& result = profile_results.back(); + result.mutable_conv()->set_algorithm(alg.algo_id()); + result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled()); + + if (!launch_status.ok()) { + result.set_error_string(launch_status.error_message()); + continue; + } + + if (!profile_result.is_valid()) { + result.set_error_string("Invalid profile result"); + continue; + } + + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + result.mutable_success()->set_scratch_bytes(scratch_bytes_used); + *result.mutable_success()->mutable_run_time() = + protobuf_util::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); + + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + + if (comparator.has_value()) { + StatusOr compare_result = comparator->CompareEqual( + se::DeviceMemory(result_buffer)); + if (!compare_result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << compare_result.status(); + CHECK(!crash_on_checking_failure); + } else if (!compare_result.ValueOrDie()) { + LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. " + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + auto* failure = result.mutable_reference_conv(); + failure->set_algorithm(first_algorithm->algo_id()); + failure->set_tensor_ops_enabled(first_algorithm->tensor_ops_enabled()); } - int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) - << " succeeded, taking " << profile_result.elapsed_time_in_ms() - << "ms and using " << NumBytesToString(scratch_bytes_used) - << " of scratch (Best result: " - << best_result.elapsed_time_in_ms() << "ms, " - << NumBytesToString(best_result_bytes_used) << " of scratch)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - best_result_bytes_used = scratch_bytes_used; + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(result_buffer), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " << comp.status() + << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; } } - if (best_result.is_valid()) { - VLOG(2) << "Best algorithm for " << instr->ToString() << ": " - << AlgorithmToString(best_result.algorithm()) << ", takes " - << best_result.elapsed_time_in_ms() << "ms, and uses " - << best_result_bytes_used << "B of scratch memory."; - return AutotuneResult{best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used, - absl::Milliseconds(best_result.elapsed_time_in_ms())}; + + // Log the autotuning result. + { + AutotuneLog log; + *log.mutable_instr()->mutable_instruction() = instr->ToProto(); + for (const auto* op : instr->operands()) { + *log.mutable_instr()->add_operand_shapes() = op->shape().ToProto(); + } + for (const auto& profile : profile_results) { + *log.add_results() = profile; + } + *log.mutable_compute_capability() = GetComputeCapability(stream_exec_); + *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_); + VLOG(2) << "Autotuning result:\n" << log.DebugString(); + tensorflow::Logger::Singleton()->LogProto(log); + } + + auto* profile_results_end = profile_results.data() + profile_results.size(); + + const AutotuneResult* best_result = std::min_element( + profile_results.data(), profile_results_end, + [](const AutotuneResult& lhs, const AutotuneResult& rhs) { + // The successful one should have a smaller key, since we are doing + // min_element. If they are both unsuccessful, keep the earlier one in + // the vector by comparing pointers. + return std::make_tuple( + !lhs.has_success(), + protobuf_util::FromDurationProto(lhs.success().run_time()), + &lhs) < std::make_tuple(!rhs.has_success(), + protobuf_util::FromDurationProto( + rhs.success().run_time()), + &rhs); + }); + + if (best_result != profile_results_end && best_result->has_success()) { + return *best_result; } return InternalError( @@ -341,22 +401,23 @@ StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( } auto best_algo = std::move(best_algo_or).ValueOrDie(); - VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm - << " and " << NumBytesToString(best_algo.scratch_bytes) + VLOG(1) << "Setting cudnn conv to use algorithm " + << best_algo.conv().algorithm() << " and " + << NumBytesToString(best_algo.success().scratch_bytes()) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled(); // Replace instr with a new CustomCall which has the correct algorithm, and // whose output shape has the appropriate amount of scratch memory. HloComputation* computation = instr->parent(); Shape new_call_shape = ShapeUtil::MakeTupleShape( {instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); + ShapeUtil::MakeShape(U8, {best_algo.success().scratch_bytes()})}); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, instr->backend_config()); - backend_config.set_algorithm(best_algo.algorithm); - backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); + backend_config.set_algorithm(best_algo.conv().algorithm()); + backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled()); HloInstruction* new_call = computation->AddInstruction( instr->CloneWithNewOperands(new_call_shape, instr->operands())); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 4991db0948589e479a202f4082d96df275f6e088..2e34ba9672314a62290b8a557960a605a98996c7 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -47,13 +48,6 @@ class CudnnConvAlgorithmPicker : public HloModulePass { StatusOr Run(HloModule* module) override; private: - struct AutotuneResult { - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - absl::Duration runtime; - }; - StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); StatusOr PickBestAlgorithm( diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index ffd4214958275dc79bbcb060328893f8b68c737a..80ddb3e5cde708f535f71b75587171ed975f939c 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -448,7 +448,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( return Load(accum_ptr); }; case HloOpcode::kReduce: - // TODO(b/112040122): This should be supported. + // TODO(b/118332391): This should be supported. CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 86c9bc6a345047fb5329af0be45c8981cc427f50..a7053e6a013be3ccf5725cbe003558be77104af1 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -428,7 +428,8 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, scratch_data = scratch_mem->device_memory(); } const MatrixDescriptor scratch_descriptor( - scratch_data, false, output_num_cols, output_num_rows, batch_size); + scratch_data, false, output_matrix.num_rows, output_matrix.num_cols, + batch_size); StatusOr best_algorithm = GetGemmAutotuneFn( element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e9d7ba1c4cfa865532a0d06c2ed883a2fea4e2cd..9f0de3f794decb7b878b67c96030f8e11b0555fe 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -48,7 +48,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr); // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. -// This function works for both, sibling and producer-conumser multi-output +// This function works for both, sibling and producer-consumer multi-output // fusion. // So far, multi-output fusion is supported for loop fusions and reduce // input fusions only. It is up to the caller to ensure the instructions diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4268fb2c7a813b3b53e4cd48746028a7b369f28e..4765f67c4b17e97419182e341573f75ad3d6ac30 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 58bdd4209a2315cdb7d29e920faded4d1a6a5876..a6d80f0b6dddb3d8d0fd00c639e11c71da6a9f09 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -240,6 +240,32 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(keys_layout, *output_buffer)); } + } else if (instruction->opcode() == HloOpcode::kTriangularSolve) { + // TODO(phawkins): Ideally we would relax this constraint. What we + // actually want is that: + // a) the batch dimensions are major, in no particular order. + // b) the two minor dimensions are in fortran (column-major) order, + // although for the 'a' argument we could potentially accept row-major + // order and fold the transpose into the operator. + auto set_fortran_layout = [](Shape* shape) { + LayoutUtil::SetToDefaultLayout(shape); + int n = shape->mutable_layout()->minor_to_major_size(); + CHECK_GE(n, 2); + std::swap(shape->mutable_layout()->mutable_minor_to_major()->at(0), + shape->mutable_layout()->mutable_minor_to_major()->at(1)); + }; + Shape op0_shape = instruction->operand(0)->shape(); + Shape op1_shape = instruction->operand(1)->shape(); + Shape output_shape = instruction->shape(); + set_fortran_layout(&op0_shape); + set_fortran_layout(&op1_shape); + set_fortran_layout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 29756d27260b0f41b2dd4b649ea9b1610ff90268..391029e574622925b2a7e801a7d41d95e49a1cfb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -368,12 +368,21 @@ TEST_F(LayoutAssignmentTest, DotLayout) { TEST_F(LayoutAssignmentTest, SortLayout) { const char* hlo_text = R"( HloModule SortLayout + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + ENTRY sort { keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}}) values = f32[2,3]{1,0} parameter(0) transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), - dimensions={1} + dimensions={1}, to_apply=compare })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 8c6a6914792a96ab517fa5f20ff2215e4785490e..e593f535642e15f28a4a1c1f321881ba3c694548 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 82bdd677d96d3d0826bb4127b32d074eb632b1a3..3ed6553f9205803cfa17772b890c449cfb457c89 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 0007a9a8a3369d8ac010640127e1561615a6d813..8f010ab27a6c99b97e7808218de908ce558b0fe7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -492,8 +492,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); result = InsertValue(result, value.first, {0}); result = InsertValue(result, value.second, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { result = FMul(lhs_value, rhs_value); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + result = Mul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -583,9 +586,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum_imag = Imag(accum, &b_); llvm::Value* imag_sum = FAdd(accum_imag, value.second); updated_accum = InsertValue(updated_accum, imag_sum, {1}); - } else { + } else if (ShapeUtil::ElementIsFloating(lhs_shape)) { llvm::Value* product = FMul(lhs_element, rhs_element); updated_accum = FAdd(accum, product); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsIntegral(lhs_shape)); + llvm::Value* product = Mul(lhs_element, rhs_element); + updated_accum = Add(accum, product); } Store(updated_accum, accum_address); @@ -647,7 +654,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { } Status IrEmitter::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support variadic reduce. + // TODO(b/118332391): Support variadic reduce. if (!reduce->shape().IsArray()) { return Unimplemented("Variadic reduce is not supported on GPU"); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 294a454931b5cfa368bf094c428a1e942f4556b8..0cc65ebb52737aa9bb8866eb07278a2319aa797b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" @@ -60,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -487,6 +487,41 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { return Status::OK(); } +Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { + auto has_fortran_layout = [](const Layout& layout) { + int n = layout.minor_to_major_size(); + return layout.minor_to_major(0) == n - 2 && + layout.minor_to_major(1) == n - 1; + }; + TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->shape().layout())); + + std::vector> thunks; + + // Triangular solve is in-place on 'b', so copy 'b' to the output if they + // aren't the same buffer. + auto operand_buffer = GetAllocationSlice(*hlo->operand(1)); + auto destination_buffer = GetAllocationSlice(*hlo); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); + } + + thunks.push_back(BuildTriangularSolveThunk(hlo)); + + // Elide the sequential thunk if there's no copy. + if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), hlo)); + } + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { @@ -546,7 +581,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // a 1D array. The specialized version requires a initializer thunk that // initializes the output array to the initial value of the reduce. if (root->opcode() == HloOpcode::kReduce && root->shape().IsTuple()) { - // TODO(b/112040122): Support variadic reduce. + // TODO(b/118332391): Support variadic reduce. return Unimplemented("Variadic reduce is not supported on GPU"); } return EmitReductionToVector(fusion); @@ -635,7 +670,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { - // TODO(b/112040122): Support multi-output reduce. + // TODO(b/118332391): Support multi-output reduce. if (!reduce->shape().IsArray()) { return Unimplemented("Multi-output reduce is not supported on GPU"); } @@ -959,18 +994,18 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { BuildKernelThunk(scatter, /*implements_whole_instruction=*/thunks.empty())); - TF_RETURN_IF_ERROR( - EmitScatter(thunks.back().get(), scatter, - /*scatter_indices_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*scatter_indices, *scatter) - .EmitReadArrayElement(index, &b_, "scatter_index"); - }, - /*updates_gen=*/ - [=](const IrArray::Index& index) { - return GetIrArray(*updates, *scatter) - .EmitReadArrayElement(index, &b_, "update"); - })); + TF_RETURN_IF_ERROR(EmitScatter( + thunks.back().get(), scatter, + /*scatter_indices_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*scatter_indices, *scatter) + .EmitReadArrayElement(index, &b_, "scatter_index"); + }, + /*updates_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*updates, *scatter) + .EmitReadArrayElement(index, &b_, "update"); + })); // Elide the sequential thunk if there's no copy. if (thunks.size() == 1) { @@ -1118,17 +1153,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; Shape keys_shape = sort->operand(0)->shape(); int64 dimension_to_sort = sort->dimensions(0); - // In case there is a 'values' parameter that is a iota, we take note and use - // it later to ensure a stable sort. Otherwise, we don't guarantee a stable - // sort. - int64 iota_values_parameter_index = -1; for (int64 i = 0; i < sort->operand_count(); ++i) { - if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && - ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && - Cast(sort->operand(i))->iota_dimension() == - dimension_to_sort) { - iota_values_parameter_index = i; - } ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the @@ -1241,25 +1266,23 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); - IrArray keys_array; std::vector values_arrays; - values_arrays.reserve(sort->operand_count() - 1); + values_arrays.reserve(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); - if (i == 0) { - keys_array = GetIrArray(*sort, *sort, shape_index); - } else { - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); - } + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); } return llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, - iota_values_parameter_index, IrName(sort), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, - kTileSize); + kTileSize, + [&](absl::Span operands, llvm::Value* output) { + return EmitCallToNestedComputation(*sort->to_apply(), operands, + output); + }); }; std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { @@ -1758,6 +1781,29 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } +std::unique_ptr IrEmitterUnnested::BuildTriangularSolveThunk( + const HloInstruction* inst) { + const HloInstruction* a = inst->operand(0); + const HloInstruction* b = inst->operand(1); + int64 m = b->shape().dimensions(b->shape().rank() - 2); + int64 n = b->shape().dimensions(b->shape().rank() - 1); + int64 batch_size = std::accumulate( + b->shape().dimensions().begin(), b->shape().dimensions().end() - 2, + int64{1}, [](int64 a, int64 b) { return a * b; }); + int64 elem_size = + ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type()); + int64 a_batch_stride = inst->triangular_solve_options().left_side() + ? m * m * elem_size + : n * n * elem_size; + int64 b_batch_stride = m * n * elem_size; + return absl::make_unique( + inst->triangular_solve_options(), + /*a_input_buffer=*/GetAllocationSlice(*a), + /*b_input_buffer=*/GetAllocationSlice(*inst), + inst->shape().element_type(), batch_size, m, n, a_batch_stride, + b_batch_stride, inst); +} + StatusOr> IrEmitterUnnested::BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); @@ -2133,7 +2179,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } - int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, @@ -2967,12 +3012,11 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel // *anyway*. if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { - KernelSupportLibrary{&b_}.If( - "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), - ConstructIrArrayForOutputs(*unnested_hlo), &b_, - module_); - }); + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + }); } // For each tiled parameter, cast its input IrArray to the corresponding diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 21b842bb2cd63ac454f85556df20ae5877cecbe1..f85e18bbf0798ef3d5b87e81d287d8aed691dfc4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -176,6 +176,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; + Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* after_all) override; @@ -319,6 +320,9 @@ class IrEmitterUnnested : public IrEmitter { // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); + // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`. + std::unique_ptr BuildTriangularSolveThunk(const HloInstruction* inst); + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 1f4f1766618c71c9ef8705f3038676a0518b3ddd..6e00e4b4ff8c493f00fae3355215fb13fb5f4f10 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -118,6 +119,9 @@ std::vector GetCudaRootCandidates( const HloModuleConfig& hlo_module_config) { std::vector potential_cuda_roots = tensorflow::CandidateCudaRoots(); + // "." is our last resort, even though it probably won't work. + potential_cuda_roots.push_back("."); + // CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has // highest priority. string xla_gpu_cuda_data_dir = @@ -129,9 +133,23 @@ std::vector GetCudaRootCandidates( return potential_cuda_roots; } +void PrintCantFindCudaMessage(absl::string_view msg, + const HloModuleConfig& hlo_module_config) { + LOG(WARNING) << msg; + LOG(WARNING) << "Searched in the following directories:"; + for (const auto& dir : GetCudaRootCandidates(hlo_module_config)) { + LOG(WARNING) << " " << dir; + } + LOG(WARNING) + << "You can choose the search directory by setting xla_gpu_cuda_data_dir " + "in HloModule's DebugOptions. For most apps, setting the environment " + "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."; +} + // Returns the directory containing nvvm libdevice files. string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { - for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) { + const auto& candidate_dirs = GetCudaRootCandidates(hlo_module_config); + for (const string& cuda_root : candidate_dirs) { string libdevice_dir = tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); VLOG(2) << "Looking for libdevice at " << libdevice_dir; @@ -140,8 +158,14 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { return libdevice_dir; } } - LOG(WARNING) << "Unable to find libdevice dir. Using '.'"; - // Last resort: maybe in the current folder. + PrintCantFindCudaMessage( + "Can't find directory containing CUDA libevice. This may result in " + "compilation or runtime failures, if the program we try to run uses " + "routines from libdevice.", + hlo_module_config); + + // GetCudaRotCandidates always inclues ".", but but if everything fails, we + // return it anyway. Better than returning the empty string. return "."; } @@ -172,6 +196,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/true); + // Expand the sort op to support stable sorting if required. + pipeline.AddPass(); // Convert BF16 operations to F32 operations so that the GPU backend can // support BF16 operations without directly implementing a BF16 lowering for // most ops. @@ -772,14 +798,19 @@ StatusOr> NVPTXCompiler::RunBackend( std::unique_ptr profile_index_map; std::unique_ptr profile_printer; - if (module->config().hlo_profiling_enabled()) { + if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = absl::make_unique(*module); - profile_printer = CreateHloProfilePrinterData( - *profile_index_map, cost_analysis, entry_computation->name()); + VLOG(1) << "HLO memory read+written: " + << tensorflow::strings::HumanReadableNumBytes( + cost_analysis.bytes_accessed()); + if (module->config().hlo_profiling_enabled()) { + profile_index_map = absl::make_unique(*module); + profile_printer = CreateHloProfilePrinterData( + *profile_index_map, cost_analysis, entry_computation->name()); + } } auto* gpu_executable = new GpuExecutable( @@ -843,10 +874,11 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( log_warning = !warning_done.exchange(true); } if (log_warning) { - LOG(WARNING) - << "Failed to compile ptx to cubin. Will attempt to let " - "GPU driver compile the ptx. " - << maybe_cubin.status(); + PrintCantFindCudaMessage( + "Can't find ptxas binary. Will back to the GPU driver " + "for PTX -> sass compilation. This is OK so long as you don't " + "see a warning below about an out-of-date driver version.", + hlo_module_config); } // We're going to use the driver to JIT our PTX->SASS, so warn if diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 8154d75d23a6d49153ccb6824402aff73f365617..cb012649200c6386d3ae25d088aa3b16bd40be82 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index a1ed8499040359fe7265a7317b0577a990a2234c..d33e9cf714ee3810b1fb2fa8c05c3ed399d27bfb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index f91a22d482bc8bc046977870a7a4d18ca1acde68..06b06a5b1ee1fb9996be3ebe326893c4160a7e29 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc index c78605cebbc671272b8df9faf0e0cc54be2f5b1c..a677617727c04811584cbaa295d164ed27273bb2 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -48,6 +48,8 @@ std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { return os << "kOutfeed"; case Thunk::kSequential: return os << "kSequential"; + case Thunk::kTriangularSolve: + return os << "kTriangularSolve"; case Thunk::kTuple: return os << "kTuple"; case Thunk::kWhile: diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index e68bee035a029178844282995429eaa960cc4817..bc69af897a01775d2d33d46067464b10e049f3e1 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -56,6 +56,7 @@ class Thunk { kMemzero, kOutfeed, kSequential, + kTriangularSolve, kTuple, kWhile, }; diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..5200a2af412979c7e38d95c5a9bd5bc2ab64f086 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.cc @@ -0,0 +1,149 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" +#include "tensorflow/stream_executor/device_memory.h" + +namespace xla { +namespace gpu { + +TriangularSolveThunk::TriangularSolveThunk( + const TriangularSolveOptions& options, + const BufferAllocation::Slice& a_buffer, + const BufferAllocation::Slice& b_buffer, PrimitiveType type, + int64 batch_size, int64 m, int64 n, int64 a_batch_stride, + int64 b_batch_stride, const HloInstruction* hlo) + : Thunk(Kind::kTriangularSolve, hlo), + uplo_(options.lower() ? se::blas::UpperLower::kLower + : se::blas::UpperLower::kUpper), + side_(options.left_side() ? se::blas::Side::kLeft + : se::blas::Side::kRight), + unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit + : se::blas::Diagonal::kNonUnit), + a_buffer_(a_buffer), + b_buffer_(b_buffer), + type_(type), + batch_size_(batch_size), + m_(m), + n_(n), + a_batch_stride_(a_batch_stride), + b_batch_stride_(b_batch_stride) { + transpose_a_ = [&] { + switch (options.transpose_a()) { + case TriangularSolveOptions::NO_TRANSPOSE: + return se::blas::Transpose::kNoTranspose; + case TriangularSolveOptions::TRANSPOSE: + return se::blas::Transpose::kTranspose; + case TriangularSolveOptions::ADJOINT: + return se::blas::Transpose::kConjugateTranspose; + default: + LOG(ERROR) << "Invalid triangular solve transpose value " + << options.transpose_a(); + return se::blas::Transpose::kNoTranspose; + } + }(); +} + +Status TriangularSolveThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_) + << " side=" << se::blas::SideString(side_) + << " diagonal=" << se::blas::DiagonalString(unit_diagonal_) + << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_ + << " a_batch_stride=" << a_batch_stride_ + << " b_batch_stride=" << b_batch_stride_; + + const int lda = side_ == se::blas::Side::kLeft ? m_ : n_; + const int ldb = m_; + + char* a_base = static_cast( + buffer_allocations.GetDeviceAddress(a_buffer_).opaque()); + char* b_base = static_cast( + buffer_allocations.GetDeviceAddress(b_buffer_).opaque()); + for (int64 i = 0; i < batch_size_; ++i) { + bool launch_ok; + se::DeviceMemoryBase a_data = + se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_); + se::DeviceMemoryBase b_data = + se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_); + switch (type_) { + case F32: { + se::DeviceMemory b_data_typed(b_data); + launch_ok = stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, + unit_diagonal_, m_, n_, /*alpha=*/1.0f, + se::DeviceMemory(a_data), lda, + &b_data_typed, ldb) + .ok(); + break; + } + case F64: { + se::DeviceMemory b_data_typed(b_data); + launch_ok = stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, + unit_diagonal_, m_, n_, /*alpha=*/1.0, + se::DeviceMemory(a_data), lda, + &b_data_typed, ldb) + .ok(); + break; + } + case C64: { + se::DeviceMemory> b_data_typed(b_data); + launch_ok = + stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_, + n_, /*alpha=*/1.0f, + se::DeviceMemory>(a_data), + lda, &b_data_typed, ldb) + .ok(); + break; + } + case C128: { + se::DeviceMemory> b_data_typed(b_data); + launch_ok = + stream + ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_, + n_, /*alpha=*/1.0, + se::DeviceMemory>(a_data), + lda, &b_data_typed, ldb) + .ok(); + break; + } + default: + return InvalidArgument("Invalid type for triangular solve %d", type_); + } + if (!launch_ok) { + return InternalError("Unable to launch triangular solve for thunk %p", + this); + } + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..c947162ea32f197f808d099859eadbbc55a65ab1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/blas.h" + +namespace xla { +namespace gpu { + +// This class stores everything that StreamExecutor needs to launch a triangular +// solve (BlasTrsm). It is generated by IrEmitter. +// +// Thread-compatible. +class TriangularSolveThunk : public Thunk { + public: + TriangularSolveThunk(const TriangularSolveOptions& options, + const BufferAllocation::Slice& a_buffer, + const BufferAllocation::Slice& b_buffer, + PrimitiveType type, int64 batch_size, int64 m, int64 n, + int64 a_batch_stride, int64 b_batch_stride, + const HloInstruction* hlo); + + TriangularSolveThunk(const TriangularSolveThunk&) = delete; + TriangularSolveThunk& operator=(const TriangularSolveThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + const se::blas::UpperLower uplo_; + const se::blas::Side side_; + const se::blas::Diagonal unit_diagonal_; + se::blas::Transpose transpose_a_; + + const BufferAllocation::Slice a_buffer_; + const BufferAllocation::Slice b_buffer_; + + const PrimitiveType type_; + const int64 batch_size_; + const int64 m_; + const int64 n_; + const int64 a_batch_stride_; + const int64 b_batch_stride_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRIANGULAR_SOLVE_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc index c552c2925497f1c4808d74a615d35cdbeeba1858..bbbcc2dbb0f71d08462a1aad6d97e7fd07b2a1fb 100644 --- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 263b42a29dbb0dbc0fb6eca7968674ff242f45ed..ae9e3169fd9b7a4655ab91ffb1589b845402ba8d 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 59 +// Next ID: 62 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -175,6 +175,9 @@ message HloInstructionProto { // partners. bool is_host_transfer = 47; + // Whether this Sort instruction should be stable. + bool is_stable = 60; + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. @@ -193,6 +196,12 @@ message HloInstructionProto { // operand. bool constrain_layout = 56; repeated xla.ShapeProto operand_shapes_with_layout = 57; + + // Options for TriangularSolve + xla.TriangularSolveOptions triangular_solve_options = 59; + + // Describes how parameters behave with regards to replicas. + xla.ParameterReplication parameter_replication = 61; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 40fe91398be33f5681e1389e1b6fadcbd87487bb..817e15f9ff10a9b7e1a502265c85f70fdd681dd9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -296,7 +296,7 @@ void ComputeComputationPostOrder(HloComputation* computation, } // namespace void HloComputation::ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_group, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const { std::vector dfs_stack; @@ -320,66 +320,75 @@ void HloComputation::ComputeInstructionPostOrder( visited->insert({current, kVisiting}); - // Add the operands to the stack in reverse order so the first operand is - // processed first. This will produce a more natural ordering and a nicer - // result for things like HLO stringification. - const auto& operands = current->operands(); - for (int64 i = operands.size() - 1; i >= 0; --i) { - dfs_stack.emplace_back(operands[i]); - } - - for (HloInstruction* op : current->control_predecessors()) { - dfs_stack.emplace_back(op); - } - - // Add inputs for send->recv_done dependencies and all-reduce - // dependencies. - switch (current->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(current->channel_id()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } - break; + const auto get_channel_id = + [](HloInstruction* inst) -> absl::optional { + switch (inst->opcode()) { + case HloOpcode::kRecvDone: + return inst->channel_id(); + case HloOpcode::kAllReduce: + return inst->all_reduce_id(); + default: + return absl::nullopt; } - case HloOpcode::kAllReduce: { - auto all_reduce_id = current->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - for (HloInstruction* op : it->second) { - dfs_stack.emplace_back(op); - } - } + }; + + // When adding a predecessor to the dfs_stack, we need to also add its + // associated channel dependencies. + const auto add_dfs_stack = [&](HloInstruction* inst) { + auto channel_id = get_channel_id(inst); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + dfs_stack.emplace_back(cinst); } - break; + } else { + dfs_stack.emplace_back(inst); } - default: - break; + }; + + const auto add_predecessors = [&](HloInstruction* inst) { + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for things like HLO stringification. + const auto& operands = inst->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + add_dfs_stack(operands[i]); + } + + for (HloInstruction* op : inst->control_predecessors()) { + add_dfs_stack(op); + } + }; + + // If the current instruction is a channel instruction, add the dependencies + // from all associated instructions of the channel. + auto channel_id = get_channel_id(current); + if (channel_id && channel_dependency_group.count(*channel_id)) { + auto it = channel_dependency_group.find(*channel_id); + for (HloInstruction* cinst : it->second) { + add_predecessors(cinst); + } + } else { + add_predecessors(current); } } } -HloComputation::ChannelDependencyMap +HloComputation::ChannelDependencyGroup HloComputation::ComputeChannelDependencies() const { - ChannelDependencyMap channel_dependency_map; + ChannelDependencyGroup channel_dependency_group; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { - case HloOpcode::kSend: { - channel_dependency_map[instruction->channel_id()].push_back( + case HloOpcode::kSend: + case HloOpcode::kRecvDone: + channel_dependency_group[instruction->channel_id()].push_back( instruction.get()); break; - } case HloOpcode::kAllReduce: { auto all_reduce_id = instruction->all_reduce_id(); if (all_reduce_id) { - auto& dependencies = channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(instruction->operands(), - std::back_inserter(dependencies)); - absl::c_copy(instruction->control_predecessors(), - std::back_inserter(dependencies)); + channel_dependency_group[all_reduce_id.value()].push_back( + instruction.get()); } break; } @@ -387,11 +396,11 @@ HloComputation::ComputeChannelDependencies() const { break; } } - return channel_dependency_map; + return channel_dependency_group; } std::vector HloComputation::MakeInstructionPostOrder() const { - auto channel_dependency_map = ComputeChannelDependencies(); + auto channel_dependency_group = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; @@ -404,7 +413,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(channel_dependency_map, &post_order, + ComputeInstructionPostOrder(channel_dependency_group, &post_order, instruction.get(), &visited); } } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 0cb9caddd089011f3e9a4473995847dc966dd402..212dfa15a13185f1050103739fad8b560270d401 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -369,13 +369,13 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // all-reduce the union of the dependencies for all participating - // instructions. - using ChannelDependencyMap = + // Returns a map from channel-id to the group of instructions associated with + // the channel. These instructions will be considered as a single node for + // dependency purposes. Send and RecvDone are in the group, and AllReduces + // with the same channel id are in the group. + using ChannelDependencyGroup = absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; + ChannelDependencyGroup ComputeChannelDependencies() const; // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. @@ -391,6 +391,10 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // Clear the unique ID of the computation so that it can be re-assigned, such + // as for the purpose of compacting the unique IDs. + void ClearUniqueIdInternal() { unique_id_ = -1; } + // The id of this computation should be unique within the module. void SetUniqueId(int64 id) { CHECK_EQ(unique_id_, -1); @@ -434,7 +438,7 @@ class HloComputation { enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyMap& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_map, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 3b88e9745c27d6e1f2a46e5c83ac2e8bd8d05150..fe37ca6b3963430c765f27aede4f506366fc5d97 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -24,7 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,6 +39,7 @@ namespace xla { namespace { namespace m = match; +namespace op = xla::testing::opcode_matchers; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -668,5 +671,34 @@ TEST_F(HloComputationTest, DeepEquality) { EXPECT_FALSE(*computation_c == *computation_b); } +// Tests that cross-module AllReduce instructions are ordered before all their +// predecessors and after all their successors. +TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) { + const char* const hlo_string = R"( +HloModule Module + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY entry { + param = f32[128] parameter(0), sharding={maximal device=0} + crs0 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=0} + crs1 = f32[128] all-reduce(param), + replica_groups={{0}}, all_reduce_id=1, barrier="", to_apply=add, + sharding={maximal device=1} + add = f32[128] add(crs0, crs0), sharding={maximal device=0} + ROOT t = (f32[128], f32[128]) tuple(add, crs1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(), + ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(), + op::Add(), op::Tuple())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1ee958114ebfa976cea72e901432575b7dc58321..29ac263c5f39b5ec1f02a232704adcd3e3f21f60 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -245,7 +245,7 @@ Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { for (auto dim : dnums.lhs_contracting_dimensions()) { reduction_width *= lhs_shape.dimensions(dim); } - // Each output elment requires reduction_widht FMA operations. + // Each output elment requires reduction_width FMA operations. current_properties_[kFlopsKey] = kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width; return Status::OK(); @@ -546,6 +546,21 @@ Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { return Status::OK(); } +Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) { + float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; + bytes_accessed += GetShapeSize(hlo->operand(1)->shape()); + current_properties_[kBytesAccessedKey] = bytes_accessed; + + const Shape& a_shape = hlo->operand(0)->shape(); + const Shape& b_shape = hlo->operand(1)->shape(); + // Estimate as batch * mn^2 / 2 flops. + int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); + elems *= ShapeUtil::ElementsIn(b_shape); + // Each output elment requires reduction_widht FMA operations. + current_properties_[kFlopsKey] = kFmaFlops * elems; + return Status::OK(); +} + Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 421786e20a3d9528ea76a44b3087ab2aed81d2b5..96357dec68e390251c43c2c3fc6f5a5612063fbd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; + Status HandleTriangularSolve(const HloInstruction* hlo) override; Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index d56f673455f9129b72e9d85eaf8cbf03cfee4302..b5d9e8e7f1a703d5d914a12d5226d53821071be6 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -17,10 +17,15 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -268,6 +273,29 @@ StatusOr MakeSelectHlo(HloInstruction* pred, select_shape, HloOpcode::kSelect, pred, on_true, on_false)); } +StatusOr MakeSortHlo( + const Shape& sort_shape, absl::Span operands, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, + HloModule* module) { + CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; + HloComputation* compare_computation; + XlaBuilder b("Sort.Compare"); + std::vector operand_types(operands.size()); + for (int64 i = 0; i < operands.size(); ++i) { + operand_types[i] = operands[i]->shape().element_type(); + } + XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module); + compare_computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + return builder->AddInstruction(HloInstruction::CreateSort( + sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 1c3174e9c89c16cb11589e7c0235bdf13eae6b85..17b7a2da6a9da994ea2d496b549eec79278b56b5 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -123,6 +123,15 @@ StatusOr MakeSelectHlo(HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false); +// Creates a Sort HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. Also creates a +// default compare sub-computation which sorts the first operand into ascending +// order. 'is_stable' specifies whether the sorting should be stable. +StatusOr MakeSortHlo( + const Shape& sort_shape, absl::Span operands, + int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder, + HloModule* module); + // Creates an R1 Constant HLO instruction of the given PrimitiveType with the // given values and adds it to the given computation. template diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index e602107cbe64320a8e8e740168cb294ec6be9667..849cac278ee379122ba1ff9fade3bf003969b8a7 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4a7c4963b7b399e625da907b3810c42df7ee2bd3..768e3afb3b80698061b62c4aadef09c20e2f286c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -2356,14 +2357,17 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto sort = - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false, + &builder, module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); EXPECT_TRUE( dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); @@ -2371,6 +2375,7 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); Shape values_shape = ShapeUtil::MakeShape(F32, {8}); @@ -2378,11 +2383,14 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { HloInstruction::CreateParameter(0, keys_shape, "keys")); auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); - auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, - {values})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); // The buffer for the keys can be shared with the first tuple entry. EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 9b0f2b2a0f4dd5d1d1191e9ab0637cc3034b50da..7d6b86056af3fc2128fe1642bbfa0ca6f9ef1da0 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -127,6 +127,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops where it does not make sense to convert them. if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kBitcastConvert || opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { continue; @@ -145,7 +146,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || - opcode == HloOpcode::kConditional) { + opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) { continue; } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5b633784e2f306290ca6c096f67c657be1f188c8..4171f738620dbf545e5883b8c26169fae4b93643 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -176,5 +176,19 @@ ENTRY main { EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0)); } +TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) { + const string& hlo_string = R"( + HloModule test + + ENTRY test { + p = bf16[] parameter(0) + ROOT c = u16[] bitcast-convert(p) + })"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + HloElementTypeConverter converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get())); + EXPECT_FALSE(converted); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 56a1b6f43945adae18313546432b959f66a32dcf..4d6487700b24cfd3b89aece58e5ad6d7bb43a800 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -1461,14 +1462,6 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { } Shape key_shape = sort->operand(0)->shape(); auto rank = key_shape.rank(); - PrimitiveType keys_type = key_shape.element_type(); - if (keys_type != F64 && keys_type != U64 && keys_type != S64 && - keys_type != F32 && keys_type != U32 && keys_type != S32 && - keys_type != BF16 && keys_type != F16 && keys_type != U16 && - keys_type != S16 && keys_type != U8 && keys_type != S8) { - return InvalidArgument("Unsupported type for Sort: %s", - PrimitiveType_Name(keys_type)); - } std::vector result_literals; result_literals.reserve(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { @@ -1479,6 +1472,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { int64 sort_dim = sort->dimensions(0); int64 sort_dim_elements = key_shape.dimensions(sort_dim); increment[sort_dim] = sort_dim_elements; + HloEvaluator embedded_evaluator(max_loop_iterations_); // Iterate through each dimension except 'sort_dim'. TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment, @@ -1499,78 +1493,51 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { } std::vector indices_to_sort(sort_dim_elements); std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0); - std::stable_sort( - indices_to_sort.begin(), indices_to_sort.end(), - [keys_type, &literals_to_sort](int64 a, int64 b) { - switch (keys_type) { - case F64: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case U64: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case S64: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case F32: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case U32: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case S32: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case BF16: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case F16: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case U16: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case S16: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case U8: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - case S8: { - auto key_lhs = literals_to_sort[0].Get({a}); - auto key_rhs = literals_to_sort[0].Get({b}); - return SafeLess(key_lhs, key_rhs); - } - default: - // We should never reach here, because we checked earlier - // that 'key_type' is one of the cases above. - LOG(FATAL) << "Invalid key type in Sort: %s", - PrimitiveType_Name(keys_type); - return false; - } - }); + Status compare_status = Status::OK(); + auto comparator = [sort, &compare_status, &embedded_evaluator, + &literals_to_sort](int64 a, int64 b) { + std::vector literals; + literals.reserve(2 * sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto lhs = ExtractFromIndexPositions(literals_to_sort[i], {a}, + /*extract_as_scalar=*/true); + if (!lhs.ok()) { + compare_status = lhs.status(); + return false; + } + literals.push_back(std::move(lhs.ValueOrDie())); + auto rhs = ExtractFromIndexPositions(literals_to_sort[i], {b}, + /*extract_as_scalar=*/true); + if (!rhs.ok()) { + compare_status = rhs.status(); + return false; + } + literals.push_back(std::move(rhs.ValueOrDie())); + } + std::vector literal_ptrs; + absl::c_transform(literals, std::back_inserter(literal_ptrs), + [](const Literal& literal) { return &literal; }); + + auto computed_result = + embedded_evaluator.Evaluate(*sort->to_apply(), literal_ptrs); + // Clear visit states so that we can use the evaluator again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + if (!computed_result.ok()) { + compare_status = computed_result.status(); + return false; + } + return computed_result.ValueOrDie().Get({}); + }; + if (Cast(sort)->is_stable()) { + std::stable_sort(indices_to_sort.begin(), indices_to_sort.end(), + comparator); + } else { + std::sort(indices_to_sort.begin(), indices_to_sort.end(), comparator); + } + if (!compare_status.ok()) { + return compare_status; + } std::vector slice_dimensions(rank, 1); slice_dimensions[sort_dim] = sort_dim_elements; std::vector start_indices(rank, 0); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 72ea40bcd797def3bc0765986881792b8752d9e1..357975a131d0c7e63c06e96852468b43d97a37f2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -331,16 +331,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { std::vector arg_literals_; // Max loop iterations to execute with no maximum if negative. - int64 max_loop_iterations_; + int64 max_loop_iterations_ = 0; // Module-level seed handle. - uint64 seed_; + uint64 seed_ = 0; // RNG engine. std::minstd_rand0 engine_; // DynamicDimensionInference is used to evaluate GetDimensionSize, which // returns the dynamic dimension size of its operand. - DynamicDimensionInference* dynamic_dimension_inference_; + DynamicDimensionInference* dynamic_dimension_inference_ = nullptr; // Optional handler for custom_call ops. std::function(HloInstruction* custom_call, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index fb8cd299cef06d549130cd56dd2c430c4c1a0387..383921fde22242b6ede95a6554f2348ab6fd4277 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -111,6 +111,24 @@ class HloEvaluatorTest : public HloTestBase { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } + void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0, + Literal src1, Literal src2) { + HloComputation::Builder b(TestName()); + auto operand0 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src0))); + auto operand1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src1))); + auto operand2 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(src2))); + b.AddInstruction(HloInstruction::CreateTernary( + expected.shape(), opcode, operand0, operand1, operand2)); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); + } + protected: explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {} HloEvaluator evaluator_; @@ -152,6 +170,33 @@ TEST_P(HloEvaluatorBf16Test, DoesClamp) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +// Verifies that clamping of int64 does not cause loss of precision +TEST_P(HloEvaluatorBf16Test, DoesClampInt64) { + auto ones = [](int bits) { return (int64{1} << bits) - 1; }; + + auto low = + LiteralUtil::CreateR2({{0, ones(54)}, {ones(54), ones(58)}}); + auto value = LiteralUtil::CreateR2({{0, ones(56)}, {0, ones(58)}}); + auto high = LiteralUtil::CreateR2( + {{ones(54), ones(55)}, {ones(56), ones(58)}}); + + Shape shape = low.shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = + LiteralUtil::CreateR2({{0, ones(55)}, {ones(54), ones(58)}}); + + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) { auto low = LiteralUtil::CreateR0(0.f); auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); @@ -254,6 +299,20 @@ TEST_F(HloEvaluatorTest, DoesDivideInt64) { TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } + +TEST_F(HloEvaluatorTest, DoesClampS64) { + auto low = LiteralUtil::CreateR1( + {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL}); + auto value = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL}); + auto high = LiteralUtil::CreateR1( + {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL}); + auto expected = LiteralUtil::CreateR1( + {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL}); + TestTernaryOp(HloOpcode::kClamp, std::move(expected), std::move(low), + std::move(value), std::move(high)); +} + TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) { auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 742a389ed04eb7303197467587223486c780a31e..d516a6258c80bda168ef4c6fd976e60946eb8b5b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #include +#include #include "absl/algorithm/container.h" #include "absl/base/casts.h" @@ -43,46 +44,6 @@ template using is_complex_t = absl::disjunction, std::is_same>; -// It's UB to use std::sort with std::less, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. -NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. -template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - return a < b; -} - -template ::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - bool lhs_is_negative = std::signbit(a); - bool rhs_is_negative = std::signbit(b); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; - } - bool lhs_nan = std::isnan(a); - bool rhs_nan = std::isnan(b); - // Exactly one number is nan? - if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; - } - return a < b; -} - -template ::value || - std::is_same::value>::type* = nullptr> -bool SafeLess(const NativeT& a, const NativeT& b) { - return SafeLess(static_cast(a), static_cast(b)); -} - // ToArithmeticSafeType(T t): // - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed // integer, and @@ -462,9 +423,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleNegate(negate); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { @@ -474,6 +435,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value || + std::is_same::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return std::isnan(elem_operand) + ? elem_operand + : std::copysign( + elem_operand != ElementwiseT(0), + elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -916,9 +894,29 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleShiftRightLogical(shrl); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + // Special case for integral type due to MSVC's std::isnan being unable to + // handle integral type. + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return static_cast( + std::min(high, std::max(value, low))); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template ::value && + !std::is_integral::value>::type* = + nullptr> Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { @@ -926,7 +924,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(NAN); } return static_cast( - std::fmin(high, std::fmax(value, low))); + std::min(high, std::max(value, low))); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -2693,12 +2691,25 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& high = parent_->GetEvaluatedLiteralFor(random->operand(1)); - std::uniform_real_distribution generator( - low.Get({}), high.Get({})); - + // std::uniform_real_distribution(a, b) can sometimes return a value + // equal to b. Unclear if this is a spec bug or an implementation bug + // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open + // interval, so we have to re-sample if we get `b` out. + // + // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 + // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 + // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 + auto low_val = low.Get({}); + auto high_val = high.Get({}); + std::uniform_real_distribution generator(low_val, high_val); TF_RETURN_IF_ERROR( result.Populate([&](absl::Span /*indexes*/) { - return generator(parent_->engine_); + while (true) { + NativeT v = generator(parent_->engine_); + if (v != high_val) { + return v; + } + } })); break; } @@ -2832,21 +2843,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { absl::Span start_indices, const Shape& result_shape) { std::vector start; - // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish - // between the cases, this currently assumes there is at least 1 index. That - // is wrong in the general case, because for scalar indices, if the operand - // is scalar, then there are no indices. This problem with resolve itself. - const HloInstruction* first_index = start_indices[0]; - if (first_index->shape().rank() == 1) { - auto start_indices_typed = - parent_->GetEvaluatedLiteralFor(first_index).data(); - start = std::vector(start_indices_typed.begin(), - start_indices_typed.end()); - } else { - for (HloInstruction* index : start_indices) { - start.push_back( - parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); - } + + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); } // Clamp the start indices so the slice is in-bounds w.r.t the operand. @@ -2879,22 +2879,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result = operand_literal.Clone(); const auto rank = result.shape().rank(); std::vector start; - // TODO(b/118437727): Remove the R1 code-path. Note that to distinguish - // between the cases, this currently assumes there is at least 1 index. That - // is wrong in the general case, because for scalar indices, if the operand - // is scalar, then there are no indices. This problem with resolve itself. - const HloInstruction* first_index = start_indices[0]; - if (first_index->shape().rank() == 1) { - auto start_indices_typed = - parent_->GetEvaluatedLiteralFor(first_index).data(); - start = std::vector(start_indices_typed.begin(), - start_indices_typed.end()); - } else { - for (HloInstruction* index : start_indices) { - start.push_back( - parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); - } + for (HloInstruction* index : start_indices) { + start.push_back( + parent_->GetEvaluatedLiteralFor(index).GetFirstElement()); } + // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. for (int64 i = 0; i < rank; ++i) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 46ee99923ee9b6d852e6190cc8de6afe0b99457e..49300b3ffe2f755d103af7877ab3fee5298eeb3e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -536,7 +535,12 @@ stylesheet=< } } - return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); + // Browsers require that we URI-encode the contents of our data URI. (It + // seems this was a relatively recent change?) In practice, this means that we + // need to escape '#'. + return StrFormat( + fmt, graph_label, + absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}})); } string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } @@ -1011,6 +1015,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: + case HloOpcode::kTriangularSolve: return kDarkBlue; case HloOpcode::kReducePrecision: return kRed; @@ -1281,8 +1286,9 @@ namespace { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. -NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, - int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter( + const HloInstruction* root, int64 radius, + const absl::flat_hash_set& boundary) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. absl::flat_hash_map nodes; @@ -1298,6 +1304,9 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, if (depth == radius) { continue; } + if (boundary.contains(instr)) { + continue; + } // Traverse into instr's operands. // @@ -1446,9 +1455,6 @@ string SaveGraph(const string& graph, case GraphRendererInterface::DOT_GRAPH: file_extension = ".dot"; break; - case GraphRendererInterface::TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; } string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); auto status = Status::OK(); @@ -1486,25 +1492,27 @@ string ExportGraph(const string& graph, } // namespace +string HloComputationToDotGraph(const HloComputation& computation, + const DotGraphOptions& options) { + DebugOptions default_debug_options; + return HloDotDumper(&computation, options.label, + options.debug_options ? *options.debug_options + : default_debug_options, + options.show_backend_config, options.profile, + NodeFilter()) + .Dump(); +} + string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; - string graph; - if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder(debug_options); - TF_CHECK_OK(builder.AddComputation(computation)); - CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), - &graph)); - graph_kind = GraphRendererInterface::TF_GRAPHDEF; - } else { - graph = - HloDotDumper(&computation, label, debug_options, show_backend_config, - hlo_execution_profile, NodeFilter()) - .Dump(); - graph_kind = GraphRendererInterface::DOT_GRAPH; - } + string graph = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); + graph_kind = GraphRendererInterface::DOT_GRAPH; string graph_url = ExportGraph(graph, graph_kind, debug_options); LOG(INFO) << "computation " << computation.name() << " [" << label @@ -1512,12 +1520,13 @@ string DumpGraph(const HloComputation& computation, const string& label, return graph_url; } -string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_backend_config) { +string DumpNeighborhoodAround( + const HloInstruction& node, int radius, bool show_backend_config, + const absl::flat_hash_set& boundary) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); + NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius, boundary); string graph = HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 8e51454ef1cf992386cc7325e32705c08bf7712f..563cea42371d370b4c9ea739418692fd74dca799 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -26,13 +26,23 @@ limitations under the License. namespace xla { namespace hlo_graph_dumper { +// Converts a HLO module to a DOT (graphviz) graph. Returns the dot graph as +// a string. +struct DotGraphOptions { + absl::string_view label; + const DebugOptions* debug_options = nullptr; + const HloExecutionProfile* profile = nullptr; + bool show_backend_config = false; +}; +string HloComputationToDotGraph(const HloComputation& computation, + const DotGraphOptions& options); + // Abstract interface for classes that render HLO graphs (e.g. DOT graph, -// tensorflow GraphDef). +// tensorflow GraphDef) to files or services. class GraphRendererInterface { public: enum GraphKind { DOT_GRAPH, - TF_GRAPHDEF, }; virtual ~GraphRendererInterface() = default; @@ -63,8 +73,12 @@ string DumpGraph(const HloComputation& computation, const string& label, // The number of nodes dumped is controlled by the radius parameter, which // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. -string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_backend_config = false); +// +// The optional boundary specifies a set of boundary nodes, beyond which nodes +// will be omitted even if they are within the radius. +string DumpNeighborhoodAround( + const HloInstruction& node, int radius, bool show_backend_config = false, + const absl::flat_hash_set& boundary = {}); // Dumps nodes on any of the paths from `from` to `to`. If there are more than // max_nodes on all paths, restricts to the max_nodes nodes on the shortest diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc index a46a107723de30176241aae01b268a8c10d991d3..265bfdf7f989b0821a98c1f774cb408b78f348fe 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1b677bc25c1d3dfd0205c3e0dfbf1fd30c646fd4..5a6915005838c5c6f0ca0dd6563b2f17f6274a64 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -132,6 +132,14 @@ StatusOr> HloInstruction::CreateFromProto( absl::Span(fft_length)); break; } + case HloOpcode::kTriangularSolve: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Triangular solve instruction should have 2 operands but sees " + << proto.operand_ids_size(); + instruction = CreateTriangularSolve(shape, operands(0), operands(1), + proto.triangular_solve_options()); + break; + } case HloOpcode::kSend: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Send instruction should have 2 operand but sees " @@ -201,11 +209,12 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Sort instruction should one called computation but sees " + << proto.called_computation_ids_size(); auto sort_operands = all_operands(); - HloInstruction* keys = sort_operands[0]; - instruction = CreateSort( - shape, proto.dimensions(0), keys, - absl::Span(sort_operands).subspan(1)); + instruction = CreateSort(shape, proto.dimensions(0), all_operands(), + computations(0), proto.is_stable()); break; } case HloOpcode::kTranspose: @@ -295,6 +304,10 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kParameter: instruction = CreateParameter(proto.parameter_number(), shape, proto.name()); + if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) { + instruction->set_parameter_replicated_at_leaf_buffers( + proto.parameter_replication().replicated_at_leaf_buffers()); + } break; case HloOpcode::kGetTupleElement: TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -790,6 +803,13 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, fft_length); } +/* static */ std::unique_ptr +HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, + HloInstruction* b, + const TriangularSolveOptions& options) { + return absl::make_unique(shape, a, b, options); +} + /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, @@ -1153,9 +1173,11 @@ HloInstruction::CreateBroadcastSequence( } /* static */ std::unique_ptr HloInstruction::CreateSort( - const Shape& shape, int64 dimension, HloInstruction* keys, - absl::Span values) { - return absl::make_unique(shape, dimension, keys, values); + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable) { + return absl::make_unique(shape, dimension, operands, + compare, is_stable); } /* static */ std::unique_ptr HloInstruction::CreateFusion( @@ -1347,6 +1369,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kTriangularSolve: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1802,6 +1825,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kTriangularSolve: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1856,7 +1880,11 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); + return ReplaceUseWithDifferentShape(user, new_producer); +} +Status HloInstruction::ReplaceUseWithDifferentShape( + HloInstruction* user, HloInstruction* new_producer) { VLOG(3) << "Replacing uses of " << name() << " in " << user->name() << " with " << new_producer->name(); @@ -1952,6 +1980,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduce: case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1971,6 +2000,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kReduce: case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -2243,7 +2273,8 @@ std::vector HloInstruction::ExtraAttributesToString( opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kAllReduce || - opcode() == HloOpcode::kScatter) { + opcode() == HloOpcode::kScatter || + opcode() == HloOpcode::kSort) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2280,6 +2311,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kReduce: case HloOpcode::kAllReduce: case HloOpcode::kScatter: + case HloOpcode::kSort: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2585,6 +2617,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleIota(this); case HloOpcode::kGetDimensionSize: return visitor->HandleGetDimensionSize(this); + case HloOpcode::kTriangularSolve: + return visitor->HandleTriangularSolve(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3296,6 +3330,19 @@ int64 HloInstruction::parameter_number() const { return Cast(this)->parameter_number(); } +void HloInstruction::set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + return Cast(this) + ->set_parameter_replicated_at_leaf_buffers( + parameter_replicated_at_leaf_buffers); +} + +const absl::optional>& +HloInstruction::parameter_replicated_at_leaf_buffers() const { + return Cast(this) + ->parameter_replicated_at_leaf_buffers(); +} + int64 HloInstruction::tuple_index() const { return Cast(this)->tuple_index(); } @@ -3452,4 +3499,8 @@ const DomainMetadata& HloInstruction::operand_side_metadata() const { const DomainMetadata& HloInstruction::user_side_metadata() const { return Cast(this)->user_side_metadata(); } + +const TriangularSolveOptions& HloInstruction::triangular_solve_options() const { + return Cast(this)->triangular_solve_options(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index c11d29d33e918a363a7df5c4ec4e53dbf407e71e..33cbb9a41bab838e02813e75e2ca6327f785b007 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -384,6 +385,14 @@ class HloInstruction { // Creates a random number generation instruction that fills a shape with // random numbers from a given distribution. + // + // The parameters to the instruction are interpreted as follows: + // + // - If `distribution` is RNG_UNIFORM, generates a number in range + // [param0, param1). + // + // - If `distribution` is RNG_NORMAL, generates a normally-distributed value + // with mean `param0` and standard deviation `param1`. static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, absl::Span parameters); @@ -435,6 +444,10 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + static std::unique_ptr CreateTriangularSolve( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( @@ -489,7 +502,7 @@ class HloInstruction { // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor - // conssits of 0(s) in `shape`. + // consists of 0(s) in `shape`. static std::unique_ptr CreateCollectivePermute( const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs); @@ -598,7 +611,6 @@ class HloInstruction { // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, // ..., inputN.value1) // ... - // TODO(b/112040122): Add support to this in HLO passes and in backends. static std::unique_ptr CreateReduce( const Shape& shape, absl::Span operands, absl::Span init_values, @@ -671,10 +683,15 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, absl::Span dimensions); - // Creates a sort op, with a keys operand, and optional values operands. + // Creates a n-ary sort op with a 'compare' computation which is used for + // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters, + // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at + // specific index positions which should be compared, and should return a + // PRED. 'is_stable' specifies whether stable sorting is required. static std::unique_ptr CreateSort( - const Shape& shape, int64 dimension, HloInstruction* keys, - absl::Span values = {}); + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For @@ -932,6 +949,10 @@ class HloInstruction { // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); + // Same as ReplaceUseWith(), but new_producer can have a different shape. + Status ReplaceUseWithDifferentShape(HloInstruction* user, + HloInstruction* new_producer); + // Replaces the specified operand with new_operand. The old and new operands // must have compatible shapes ignoring floating-point precision. // @@ -1242,6 +1263,10 @@ class HloInstruction { // on the instruction's existing name. void UniquifyName(NameUniquer* name_uniquer); + // Clear the unique ID of the instruction so that it can be re-assigned, such + // as for the purpose of compacting the instruction unique IDs. + void ClearUniqueIdInternal() { unique_id_ = -1; } + // Set the unique id for this instruction to "id" void SetUniqueId(int id) { CHECK_EQ(unique_id_, -1); // Should not be assigned already @@ -1275,6 +1300,9 @@ class HloInstruction { backend_config_ = std::move(config_str); } + bool is_default_config() const { return is_default_config_; } + void set_default_config() { is_default_config_ = true; } + // Returns a string representation of a proto in the format used by // raw_backend_config_string. // @@ -1445,6 +1473,15 @@ class HloInstruction { // Delegates to HloParameterInstruction::parameter_number. int64 parameter_number() const; + // Delegates to + // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers); + + // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers. + const absl::optional>& + parameter_replicated_at_leaf_buffers() const; + // Delegates to HloGetTupleElementInstruction::tuple_index. int64 tuple_index() const; @@ -1554,6 +1591,9 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). + const TriangularSolveOptions& triangular_solve_options() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1720,6 +1760,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // This field is assigned to true when backend_config_ is assigned to + // a default configuration. + bool is_default_config_ = false; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 785206bf7753abaef5788365fe10217b8b74ccc6..905a6fe08b4430ad862edf0886a57c9f7e9f7977 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -201,6 +201,57 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } +HloTriangularSolveInstruction::HloTriangularSolveInstruction( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options) + : HloInstruction(HloOpcode::kTriangularSolve, shape), + triangular_solve_options_(options) { + AppendOperand(a); + AppendOperand(b); +} + +HloInstructionProto HloTriangularSolveInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_triangular_solve_options() = triangular_solve_options_; + return proto; +} + +std::vector HloTriangularSolveInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return { + StrCat("left_side=", + triangular_solve_options_.left_side() ? "true" : "false"), + StrCat("lower=", triangular_solve_options_.lower() ? "true" : "false"), + StrCat("unit_diagonal=", + triangular_solve_options_.unit_diagonal() ? "true" : "false"), + StrCat("transpose_a=", TriangularSolveOptions_Transpose_Name( + triangular_solve_options_.transpose_a()))}; +} + +bool HloTriangularSolveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + const auto& options = triangular_solve_options(); + const auto& other_options = casted_other.triangular_solve_options(); + + return options.left_side() == other_options.left_side() && + options.lower() == other_options.lower() && + options.unit_diagonal() == other_options.unit_diagonal() && + options.transpose_a() == other_options.transpose_a(); +} + +std::unique_ptr +HloTriangularSolveInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + shape, new_operands[0], new_operands[1], triangular_solve_options()); +} + HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, int64 channel_id, @@ -609,14 +660,17 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( dimensions(), to_apply()); } -HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, - HloInstruction* keys, - absl::Span values) - : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { - AppendOperand(keys); - for (auto* value : values) { +HloSortInstruction::HloSortInstruction( + const Shape& shape, int64 dimension, + absl::Span operands, HloComputation* compare, + bool is_stable) + : HloInstruction(HloOpcode::kSort, shape), + dimensions_({dimension}), + is_stable_(is_stable) { + for (auto* value : operands) { AppendOperand(value); } + AppendComputation(compare); } HloInstructionProto HloSortInstruction::ToProto() const { @@ -624,12 +678,18 @@ HloInstructionProto HloSortInstruction::ToProto() const { for (int64 dimension : dimensions_) { proto.add_dimensions(dimension); } + proto.set_is_stable(is_stable()); return proto; } std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; + std::vector attrs; + attrs.push_back(StrCat("dimensions={", StrJoin(dimensions(), ","), "}")); + if (is_stable()) { + attrs.push_back("is_stable=true"); + } + return attrs; } bool HloSortInstruction::IdenticalSlowPath( @@ -637,15 +697,20 @@ bool HloSortInstruction::IdenticalSlowPath( const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return dimensions() == casted_other.dimensions(); + if (dimensions() != casted_other.dimensions()) { + return false; + } + if (is_stable() != casted_other.is_stable()) { + return false; + } + return eq_computations(to_apply(), other.to_apply()); } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - HloInstruction* keys = new_operands[0]; - return absl::make_unique(shape, dimensions(0), keys, - new_operands.subspan(1)); + return absl::make_unique( + shape, dimensions(0), new_operands, to_apply(), is_stable()); } HloTransposeInstruction::HloTransposeInstruction( @@ -1473,9 +1538,30 @@ HloParameterInstruction::HloParameterInstruction(int64 parameter_number, HloInstructionProto HloParameterInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_parameter_number(parameter_number_); + if (parameter_replicated_at_leaf_buffers_) { + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + proto.mutable_parameter_replication()->add_replicated_at_leaf_buffers( + replicated); + } + } return proto; } +std::vector HloParameterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + if (!parameter_replicated_at_leaf_buffers_) { + return result; + } + std::vector buffers_replicated_strs; + for (bool replicated : *parameter_replicated_at_leaf_buffers_) { + buffers_replicated_strs.push_back(replicated ? "true" : "false"); + } + result.push_back(StrCat("parameter_replication={", + StrJoin(buffers_replicated_strs, ","), "}")); + return result; +} + string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -1963,6 +2049,17 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (batch_group_count_ != casted_other.batch_group_count_) { return false; } + if (layout_constrained() != casted_other.layout_constrained()) { + return false; + } + if (layout_constrained()) { + for (int64 i = 0; i < operand_shapes_with_layout_.size(); ++i) { + if (!ShapeUtil::Equal(operand_shapes_with_layout_[i], + casted_other.operand_shapes_with_layout_[i])) { + return false; + } + } + } return custom_call_target_ == casted_other.custom_call_target_ && opaque_ == casted_other.opaque_; } @@ -1973,6 +2070,10 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { auto cloned = absl::make_unique( shape, new_operands, custom_call_target(), opaque()); + if (layout_constrained()) { + cloned->layout_constrained_ = true; + cloned->operand_shapes_with_layout_ = operand_shapes_with_layout(); + } if (window_ != nullptr) { cloned->set_window(*window_); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 1b4a94753cda8aba8d50836b9d51b7c3fd5807f6..4d23cb671f24623f56faa9b69015cef21752a799 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -131,6 +131,34 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloTriangularSolveInstruction : public HloInstruction { + public: + explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a, + HloInstruction* b, + const TriangularSolveOptions& options); + const TriangularSolveOptions& triangular_solve_options() const { + return triangular_solve_options_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + TriangularSolveOptions triangular_solve_options_; +}; + class HloSendRecvInstruction : public HloInstruction { public: // Returns the channel id associated with the instruction. The id is @@ -418,8 +446,8 @@ class HloReduceInstruction : public HloInstruction { class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, - HloInstruction* keys, - absl::Span values = {}); + absl::Span operands, + HloComputation* compare, bool is_stable); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -432,6 +460,7 @@ class HloSortInstruction : public HloInstruction { HloInstruction* mutable_keys() { return mutable_operand(0); } // Returns the number of value operands. int64 values_count() const { return operand_count() - 1; } + bool is_stable() const { return is_stable_; } private: std::vector ExtraAttributesToStringImpl( @@ -446,6 +475,7 @@ class HloSortInstruction : public HloInstruction { HloCloneContext* context) const override; std::vector dimensions_; + bool is_stable_; }; class HloTransposeInstruction : public HloInstruction { @@ -787,10 +817,28 @@ class HloParameterInstruction : public HloInstruction { explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, const string& name); int64 parameter_number() const { return parameter_number_; } + + // Sets and gets the whether all replicas will receive the same parameter data + // for each leaf buffer in data parallelism. + void set_parameter_replicated_at_leaf_buffers( + absl::Span parameter_replicated_at_leaf_buffers) { + CHECK_EQ(ShapeUtil::GetLeafCount(shape()), + parameter_replicated_at_leaf_buffers.size()); + parameter_replicated_at_leaf_buffers_.emplace( + parameter_replicated_at_leaf_buffers.begin(), + parameter_replicated_at_leaf_buffers.end()); + } + const absl::optional>& + parameter_replicated_at_leaf_buffers() const { + return parameter_replicated_at_leaf_buffers_; + } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -804,6 +852,10 @@ class HloParameterInstruction : public HloInstruction { HloCloneContext* context) const override; int64 parameter_number_ = 0; + + // Specifies whether each buffer has the same parameter value on all replicas + // in data parallelism. + absl::optional> parameter_replicated_at_leaf_buffers_; }; class HloGetTupleElementInstruction : public HloInstruction { @@ -903,9 +955,7 @@ class HloOutfeedInstruction : public HloInstruction { HloInstruction* token_operand, absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. - const Shape& outfeed_shape() const { - return outfeed_shape_; - } + const Shape& outfeed_shape() const { return outfeed_shape_; } // Returns the config for the Outfeed instruction. const string& outfeed_config() const { return outfeed_config_; } // Returns a serialized representation of this instruction. diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index c1a642dfea7e464aaf93ffde1e26e07c1a4b73cd..2255383322873a39c7076e0f4f0dd541bc79014d 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -153,6 +153,8 @@ TokKind HloLexer::LexToken() { return LexPercent(); case ':': return TokKind::kColon; + case '*': + return TokKind::kAsterisk; case '[': return TokKind::kLsquare; case ']': @@ -464,6 +466,8 @@ string TokKindToString(TokKind kind) { return "kComma"; case TokKind::kColon: return "kColon"; + case TokKind::kAsterisk: + return "kAsterisk"; case TokKind::kLsquare: return "kLsquare"; case TokKind::kRsquare: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 16eed21617bc7254b67090d2b5acf9ccbd82f4ea..383fb4e862b8e32771879d055e663dc821a5c839 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -38,9 +38,10 @@ enum class TokKind { kError, // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * kLsquare, kRsquare, // [ ] kLbrace, @@ -108,7 +109,7 @@ class HloLexer { LOG(FATAL) << "This token does not have string value"; } } - tensorflow::int64 GetInt64Val() const { + int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return token_state_.int64_val; } @@ -171,7 +172,7 @@ class HloLexer { const char* token_start = nullptr; TokKind current_kind; string str_val; - tensorflow::int64 int64_val; + int64 int64_val; double decimal_val; PrimitiveType primitive_type_val; }; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 759ce6541ef144ad3f84bcb87ddabf507a034305..8322870cfd6a89fc6f863da8fd4a3576e8845cd7 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -246,6 +246,34 @@ HloModuleProto HloModule::ToProto() const { return proto; } +Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const { + absl::flat_hash_set computation_names; + absl::flat_hash_set computation_ids; + absl::flat_hash_set instruction_names; + absl::flat_hash_set instruction_ids; + + for (const HloComputation* computation : computations()) { + TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) + << "Computation name is not unique: " << computation->name(); + computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); + + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) + << "Instruction name is not unique: " << instruction->name(); + instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); + } + } + return Status::OK(); +} + /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { @@ -329,28 +357,8 @@ StatusOr> HloModule::CreateFromProto( DynamicParameterBinding::CreateFromProto( proto.dynamic_parameter_binding())); - absl::flat_hash_set computation_names; - absl::flat_hash_set instruction_names; - absl::flat_hash_set computation_ids; - absl::flat_hash_set instruction_ids; - for (HloComputation* computation : module->computations()) { - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) - << "Computation name is not unique: " << computation->name(); - computation_names.insert(computation->name()); - - TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) - << "Computation id is not unique: " << computation->unique_id(); - computation_ids.insert(computation->unique_id()); - for (HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) - << "Instruction name is not unique: " << instruction->name(); - instruction_names.insert(instruction->name()); - - TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) - << "Instruction id is not unique: " << instruction->unique_id(); - instruction_ids.insert(instruction->unique_id()); - } - } + TF_RETURN_IF_ERROR( + module->CheckUniqueNamesAndIdsForComputationsAndInstructions()); if (proto.has_schedule()) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f1310e4b270898a21dbb4f86123edde4ba8993d0..b6fe6a5cdbd0934014f1152acd48c7a5973bead3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -187,6 +187,7 @@ class HloModule { std::vector MakeNonfusionComputations() const; const HloModuleConfig& config() const { return config_; } + void set_config(HloModuleConfig& config) { config_ = config; } // Return a string representation of the module. // @@ -264,6 +265,18 @@ class HloModule { const HloSchedule& schedule() const { return *schedule_; } HloSchedule& schedule() { return *schedule_; } + HloComputation* AddComputationAndUnifyNamesAndIds( + std::unique_ptr computation, bool is_entry) { + computation->ClearUniqueIdInternal(); + for (auto* instruction : computation->instructions()) { + instruction->ClearUniqueIdInternal(); + } + return AddComputationInternal(std::move(computation), is_entry, + /*uniquify_identifiers=*/true); + } + + Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 47734bc55cc00d605f4e318400be88639450343c..b877081be5775bf6c75a69ffeba28d0f2cc17f90 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -389,9 +389,10 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); - - if (!ContainsKey(companion_set_index_, instruction1) && - !ContainsKey(companion_set_index_, instruction2)) { + if (instruction1 == instruction2) { + return Status::OK(); + } else if (!ContainsKey(companion_set_index_, instruction1) && + !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( absl::make_unique>()); auto companion_set = companion_sets_.back().get(); @@ -419,7 +420,10 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; } - companion_sets_.erase(companion_sets_.begin() + index_to_remove); + // We can't remove the set from the vector because companion_set_index_ + // references sets by their index in this vector, so we reset to nullptr + // instead. + companion_sets_[index_to_remove].reset(nullptr); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 3ed95c10504141139d83eb8679a0b8144b15ad0d..84f7f2f31339ae9e98ea2301b6e6d94fcf4dedbb 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -173,7 +173,8 @@ class HloModuleGroupMetadata { // Returns the number of modules for devices (excluding the host module). int64 GetDeviceModulesCount() const; - // Returns the companion instructions for the given instruction. + // Returns the companion set for the given instruction, including the + // instruction itself. // // Precondition: IsCompanionWhile(instruction) is true. const std::vector& Companions( diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index bf9b3c811704870d9e0a36de5c38a013fba6dfe4..35626ba37541fc7a984ad05d12ebc22e9a08a550 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -137,6 +137,7 @@ namespace xla { V(kTanh, "tanh") \ V(kTrace, "trace") \ V(kTranspose, "transpose") \ + V(kTriangularSolve, "triangular-solve") \ V(kTuple, "tuple", kHloOpcodeIsVariadic) \ V(kTupleSelect, "tuple-select") \ V(kWhile, "while") \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index dd8e8ff3a52a1cdf99a2b07b83e2891f90cf85bb..f448571082e52e4b81db1c68d1e1470935386139 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -63,8 +63,7 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) { // Some functions accept either a linear index or a multi-dimensional index // (used for indexing into sparse literals). -using LinearOrMultiIndex = - absl::variant>; +using LinearOrMultiIndex = absl::variant>; // Parser for the HloModule::ToString() format text. class HloParser { @@ -83,6 +82,7 @@ class HloParser { // Stand alone parsing utils for various aggregate data types. StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); + StatusOr> ParseParameterReplicationOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); StatusOr ParsePaddingConfigOnly(); @@ -129,8 +129,8 @@ class HloParser { // given value. If the literal is dense, it myst have the default layout. // // `loc` should be the source location of the value. - bool SetValueInLiteral(LocTy loc, tensorflow::int64 value, - LinearOrMultiIndex index, Literal* literal); + bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index, + Literal* literal); bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index, Literal* literal); bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index, @@ -158,9 +158,9 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; }; // The data parsed for the kDomain instruction. @@ -180,9 +180,11 @@ class HloParser { kBracedInt64ListList, kHloComputation, kFftType, + kTriangularSolveTranspose, kWindow, kConvolutionDimensionNumbers, kSharding, + kParameterReplication, kInstructionList, kSliceRanges, kPaddingConfig, @@ -247,21 +249,21 @@ class HloParser { bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + bool ParseParameterReplication(ParameterReplication* parameter_replication); // Parses the metadata behind a kDOmain instruction. bool ParseDomain(DomainData* domain); // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, - std::vector* result); + const TokKind delim, std::vector* result); // 'parse_and_add_item' is an lambda to parse an element in the list and add // the parsed element to the result. It's supposed to capture the result. bool ParseList(const TokKind start, const TokKind end, const TokKind delim, @@ -276,12 +278,14 @@ class HloParser { std::vector* dynamic_dimensions); bool ParseShape(Shape* result); bool ParseLayout(Layout* layout); + bool ParseTiles(std::vector* tiles); bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); + bool ParseTriangularSolveTranspose(TriangularSolveOptions::Transpose* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParsePrecision(PrecisionConfig::Precision* result); - bool ParseInt64(tensorflow::int64* result); + bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseComplex(std::complex* result); bool ParseBool(bool* result); @@ -643,6 +647,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, std::unordered_map attrs; optional sharding; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional parameter_replication; + attrs["parameter_replication"] = {/*required=*/false, + AttrTy::kParameterReplication, + ¶meter_replication}; optional> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; @@ -656,7 +664,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - tensorflow::int64 parameter_number; + int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number)) { @@ -688,7 +696,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { - optional iota_dimension; + optional iota_dimension; attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || @@ -853,13 +861,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReplicaId: { - if (!ParseOperands(&operands, /*expected_size=*/1) || + if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - if (!operands.empty()) { - return false; - } instruction = builder->AddInstruction(HloInstruction::CreateReplicaId()); break; } @@ -894,17 +899,21 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSort: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; + optional is_stable = false; + attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable}; + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || dimensions->size() != 1) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], - /*values=*/absl::Span(operands).subspan(1))); + instruction = builder->AddInstruction( + HloInstruction::CreateSort(shape, dimensions->at(0), operands, + to_apply.value(), is_stable.value())); break; } case HloOpcode::kTuple: { @@ -930,7 +939,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -946,7 +955,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -964,7 +973,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -979,7 +988,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; // If the is_host_transfer attribute is not present then default to false. optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; @@ -997,7 +1006,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -1080,7 +1089,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -1092,8 +1101,40 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, shape, operands[0], *fft_type, *fft_length)); break; } + case HloOpcode::kTriangularSolve: { + optional left_side; + optional lower; + optional unit_diagonal; + optional transpose_a; + attrs["left_side"] = {/*required=*/false, AttrTy::kBool, &left_side}; + attrs["lower"] = {/*required=*/false, AttrTy::kBool, &lower}; + attrs["unit_diagonal"] = {/*required=*/false, AttrTy::kBool, + &unit_diagonal}; + attrs["transpose_a"] = {/*required=*/false, + AttrTy::kTriangularSolveTranspose, &transpose_a}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + TriangularSolveOptions options; + if (left_side) { + options.set_left_side(*left_side); + } + if (lower) { + options.set_lower(*lower); + } + if (unit_diagonal) { + options.set_unit_diagonal(*unit_diagonal); + } + options.set_transpose_a( + transpose_a ? *transpose_a : TriangularSolveOptions::NO_TRANSPOSE); + instruction = + builder->AddInstruction(HloInstruction::CreateTriangularSolve( + shape, operands[0], operands[1], options)); + break; + } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1105,7 +1146,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -1120,7 +1161,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1136,7 +1177,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { @@ -1157,7 +1198,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1201,7 +1242,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; LocTy loc = lexer_.GetLoc(); @@ -1240,7 +1281,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1254,7 +1295,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -1270,7 +1311,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -1287,7 +1328,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -1369,8 +1410,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1481,16 +1522,16 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; optional> operand_precision; @@ -1534,19 +1575,19 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> offset_dims; + optional> offset_dims; attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &offset_dims}; - optional> collapsed_slice_dims; + optional> collapsed_slice_dims; attrs["collapsed_slice_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; - optional> start_index_map; + optional> start_index_map; attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, &start_index_map}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> slice_sizes; + optional> slice_sizes; attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, &slice_sizes}; @@ -1568,17 +1609,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kScatter: { - optional> update_window_dims; + optional> update_window_dims; attrs["update_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; - optional> inserted_window_dims; + optional> inserted_window_dims; attrs["inserted_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; - optional> scatter_dims_to_operand_dims; + optional> scatter_dims_to_operand_dims; attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &scatter_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; @@ -1619,7 +1660,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); case HloOpcode::kGetDimensionSize: - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -1644,6 +1685,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); } + if (parameter_replication) { + int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); + const auto& replicated = + parameter_replication->replicated_at_leaf_buffers(); + if (leaf_count != replicated.size()) { + return Error(lexer_.GetLoc(), + StrCat("parameter has ", leaf_count, + " leaf buffers, but parameter_replication has ", + replicated.size(), " elements.")); + } + instruction->set_parameter_replicated_at_leaf_buffers(replicated); + } if (predecessors) { for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); @@ -1708,8 +1761,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1735,7 +1788,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - tensorflow::int64 dim; + int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1747,7 +1800,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - tensorflow::int64 device; + int64 device; if (!ParseInt64(&device)) { return false; } @@ -1791,10 +1844,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - for (tensorflow::int64 dim : tile_assignment_dimensions) { + for (int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (tensorflow::int64 device : devices) { + for (int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1803,6 +1856,32 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// parameter_replication ::= +// '{' ('true' | 'false')* (',' ('true' | 'false'))* '}' +bool HloParser::ParseParameterReplication( + ParameterReplication* parameter_replication) { + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start parameter_replication attribute")) { + return false; + } + + if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (lexer_.GetKind() == TokKind::kw_true) { + parameter_replication->add_replicated_at_leaf_buffers(true); + } else if (lexer_.GetKind() == TokKind::kw_false) { + parameter_replication->add_replicated_at_leaf_buffers(false); + } else { + return false; + } + lexer_.Lex(); + } while (EatIfPresent(TokKind::kComma)); + } + + return ParseToken(TokKind::kRbrace, + "expected '}' to end parameter_replication attribute"); +} + // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' // 'exit=' exit_sharding '}' bool HloParser::ParseDomain(DomainData* domain) { @@ -1855,22 +1934,18 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(LocTy loc, tensorflow::int64 value, +bool HloParser::SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(loc, value, index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S16: - return SetValueInLiteralHelper(loc, value, index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S32: - return SetValueInLiteralHelper(loc, value, index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case S64: - return SetValueInLiteralHelper(loc, value, index, - literal); + return SetValueInLiteralHelper(loc, value, index, literal); case U8: return SetValueInLiteralHelper(loc, value, index, literal); @@ -1958,7 +2033,7 @@ bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, } // Check that the index is in range and assign into the literal - if (auto* linear_index = absl::get_if(&index)) { + if (auto* linear_index = absl::get_if(&index)) { if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) { return Error(loc, StrCat("trys to set value ", StringifyValue(value), " to a literal in shape ", @@ -2063,8 +2138,8 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { // Create a literal with the given shape in default layout. *literal = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); - tensorflow::int64 nest_level = 0; - tensorflow::int64 linear_index = 0; + int64 nest_level = 0; + int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -2072,13 +2147,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim( - elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), + elems_seen_per_dim.begin() + dim); return StrCat("[", StrJoin(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { + [](string* out, const int64& num_elems) { StrAppend(out, num_elems - 1); }), "]"); @@ -2184,7 +2259,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } else if (primitive_util::IsIntegralType(shape.element_type()) || shape.element_type() == PRED) { LocTy loc = lexer_.GetLoc(); - tensorflow::int64 value; + int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -2229,9 +2304,9 @@ bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { break; } - std::vector index; + std::vector index; if (lexer_.GetKind() == TokKind::kInt) { - tensorflow::int64 single_index = lexer_.GetInt64Val(); + int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); index.push_back(single_index); } else { @@ -2255,7 +2330,7 @@ bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - tensorflow::int64 value; + int64 value; if (!ParseInt64(&value)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -2341,7 +2416,7 @@ bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { -std::numeric_limits::infinity() == value))) { // Skip range checking for non-finite value. } else if (std::is_unsigned::value) { - CHECK((std::is_same::value || + CHECK((std::is_same::value || std::is_same::value)) << "Unimplemented checking for ParsedElemT"; @@ -2567,24 +2642,23 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kInt64: { - tensorflow::int64 result; + int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr) - ->emplace(result); + static_cast*>(attr_out_ptr)->emplace(result); return true; } case AttrTy::kInt32: { - tensorflow::int64 result; + int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -2624,6 +2698,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kTriangularSolveTranspose: { + TriangularSolveOptions::Transpose result; + if (!ParseTriangularSolveTranspose(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { @@ -2649,6 +2732,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(sharding); return true; } + case AttrTy::kParameterReplication: { + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(parameter_replication); + return true; + } case AttrTy::kInstructionList: { std::vector result; if (!ParseInstructionNames(&result)) { @@ -2668,19 +2760,19 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } case AttrTy::kBracedInt64ListList: { - std::vector> result; + std::vector> result; auto parse_and_add_item = [&]() { - std::vector item; + std::vector item; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &item)) { return false; @@ -2692,8 +2784,7 @@ bool HloParser::ParseAttributeHelper( parse_and_add_item)) { return false; } - static_cast>>*>( - attr_out_ptr) + static_cast>>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2894,7 +2985,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( absl::string_view rhs = split2[0]; absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs.length(); + const int64 rank = lhs.length(); if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -3005,7 +3096,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -3075,9 +3166,9 @@ bool HloParser::ParseShapeList(std::vector* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; if (!ParseInt64(&i)) { return false; } @@ -3153,7 +3244,7 @@ bool HloParser::ParseParamList() { bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, std::vector* dynamic_dimensions) { auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; bool is_dynamic = false; if (lexer_.GetKind() == TokKind::kLeq) { is_dynamic = true; @@ -3170,22 +3261,108 @@ bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, parse_and_add_item); } -// layout ::= '{' int64_list '}' +// tiles +// ::= /*empty*/ +// ::= 'T' '(' dim_list ')' +// dim_list +// ::= /*empty*/ +// ::= (int64 | '*') (',' (int64 | '*'))* +bool HloParser::ParseTiles(std::vector* tiles) { + auto parse_and_add_tile_dimension = [&]() { + tensorflow::int64 i; + if (ParseInt64(&i)) { + tiles->back().add_dimensions(i); + return true; + } + if (lexer_.GetKind() == TokKind::kAsterisk) { + tiles->back().add_dimensions(Tile::kCombineDimension); + lexer_.Lex(); + return true; + } + return false; + }; + + do { + tiles->push_back(Tile()); + if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma, + parse_and_add_tile_dimension)) { + return false; + } + } while (lexer_.GetKind() == TokKind::kLparen); + return true; +} + +// layout ::= '{' int64_list (':' tiles element_size_in_bits)? '}' +// element_size_in_bits +// ::= /*empty*/ +// ::= 'E' '(' int64 ')' bool HloParser::ParseLayout(Layout* layout) { std::vector minor_to_major; + std::vector tiles; + tensorflow::int64 element_size_in_bits = 0; + auto parse_and_add_item = [&]() { - tensorflow::int64 i; + int64 i; if (!ParseInt64(&i)) { return false; } minor_to_major.push_back(i); return true; }; - if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, - parse_and_add_item)) { + + if (!ParseToken(TokKind::kLbrace, + StrCat("expects layout to start with ", + TokKindToString(TokKind::kLbrace)))) { return false; } - *layout = LayoutUtil::MakeLayout(minor_to_major); + if (lexer_.GetKind() != TokKind::kRbrace) { + if (lexer_.GetKind() == TokKind::kInt) { + // Parse minor to major. + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + + if (lexer_.GetKind() == TokKind::kColon) { + lexer_.Lex(); + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") { + lexer_.Lex(); + ParseTiles(&tiles); + } + + if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") { + // Parse element size in bits. + lexer_.Lex(); + if (!ParseToken(TokKind::kLparen, + StrCat("expects element size in bits to start with ", + TokKindToString(TokKind::kLparen)))) { + return false; + } + if (!ParseInt64(&element_size_in_bits)) { + return false; + } + if (!ParseToken(TokKind::kRparen, + StrCat("expects element size in bits to end with ", + TokKindToString(TokKind::kRparen)))) { + return false; + } + } + } + } + if (!ParseToken(TokKind::kRbrace, + StrCat("expects layout to end with ", + TokKindToString(TokKind::kRbrace)))) { + return false; + } + + std::vector vec_tiles(tiles.size()); + for (int i = 0; i < tiles.size(); i++) { + vec_tiles[i] = Tile(tiles[i]); + } + *layout = + LayoutUtil::MakeLayout(minor_to_major, vec_tiles, element_size_in_bits); return true; } @@ -3237,7 +3414,7 @@ bool HloParser::ParseShape(Shape* result) { lexer_.Lex(); const string message = "expects a brace-bracketed integer for sparse layout"; - tensorflow::int64 max_sparse_elements; + int64 max_sparse_elements; if (!ParseToken(TokKind::kLbrace, message) || !ParseInt64(&max_sparse_elements) || !ParseToken(TokKind::kRbrace, message)) { @@ -3257,13 +3434,20 @@ bool HloParser::ParseShape(Shape* result) { // // The open brace could either be the start of a computation or the start of a // layout for the f32[123] shape. We consider it the start of a layout if the - // next token after the open brace is a integer + // next token after the open brace is an integer or a colon. if (lexer_.GetKind() == TokKind::kLbrace && - lexer_.LookAhead() == TokKind::kInt) { + (lexer_.LookAhead() == TokKind::kInt || + lexer_.LookAhead() == TokKind::kColon)) { Layout layout; if (!ParseLayout(&layout)) { return false; } + if (layout.minor_to_major_size() != result->rank()) { + return Error( + lexer_.GetLoc(), + StrFormat("Dimensions size is %ld, but minor to major size is %ld.", + result->rank(), layout.minor_to_major_size())); + } *result->mutable_layout() = layout; } return true; @@ -3306,15 +3490,14 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, - std::vector* result) { +bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { - tensorflow::int64 number; + int64 number; if (!ParseInt64(&number)) { return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } @@ -3333,8 +3516,7 @@ bool HloParser::ParseDxD(const string& name, return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad( - std::vector>* pad) { +bool HloParser::ParseWindowPad(std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -3344,7 +3526,7 @@ bool HloParser::ParseWindowPad( } string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { - std::vector low_high; + std::vector low_high; if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -3367,7 +3549,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -3389,7 +3571,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -3441,6 +3623,22 @@ bool HloParser::ParseFftType(FftType* result) { return true; } +bool HloParser::ParseTriangularSolveTranspose( + TriangularSolveOptions::Transpose* result) { + VLOG(1) << "ParseTriangularSolveTranspose"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects triangular solve transpose type"); + } + string val = lexer_.GetStrVal(); + if (!TriangularSolveOptions_Transpose_Parse(val, result) || + !TriangularSolveOptions_Transpose_IsValid(*result)) { + return TokenError( + StrFormat("expects triangular solve transpose type but sees: %s", val)); + } + lexer_.Lex(); + return true; +} + bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { @@ -3492,7 +3690,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { return true; } -bool HloParser::ParseInt64(tensorflow::int64* result) { +bool HloParser::ParseInt64(int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -3644,6 +3842,21 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr> HloParser::ParseParameterReplicationOnly() { + lexer_.Lex(); + ParameterReplication parameter_replication; + if (!ParseParameterReplication(¶meter_replication)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after parameter replication"); + } + return std::vector( + parameter_replication.replicated_at_leaf_buffers().begin(), + parameter_replication.replicated_at_leaf_buffers().end()); +} + StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; @@ -3759,6 +3972,11 @@ StatusOr ParseSharding(absl::string_view str) { return parser.ParseShardingOnly(); } +StatusOr> ParseParameterReplication(absl::string_view str) { + HloParser parser(str); + return parser.ParseParameterReplicationOnly(); +} + StatusOr ParseWindow(absl::string_view str) { HloParser parser(str); return parser.ParseWindowOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 450a54c54c156c2ae27475d145a8e83dc841b431..a96260b4d75e515a4cb23d315444142cae1b9587 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,11 +44,16 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string, -// e.g., "{replicated}". +// Parses sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., +// "{replicated}". StatusOr ParseSharding(absl::string_view str); +// Parses parameter replication from str. str is supposed to contain the body of +// the parameter replication, i.e. just the rhs of the +// "parameter_replication={...}" attribute string, e.g., "{true, false}". +StatusOr> ParseParameterReplication(absl::string_view str); + // Parses the result of window_util::ToString(const Window&). StatusOr ParseWindow(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b525f66f9bc837f720531b8828436ebb9c1f6b31..8e3f1e44b9562334130aa565ed447a78899fad53 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -63,6 +63,19 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } +)" +}, +// parameter replication +{ +"ParamReplication", +R"(HloModule param_replication_module + +ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) { + %a = f32[] parameter(0), parameter_replication={true} + %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true} + ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b) +} + )" }, // pred constant @@ -1047,9 +1060,15 @@ ENTRY ReducePrecision { "SortKey", R"(HloModule sort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { x = f32[1024]{0} parameter(0) - ROOT sorted = f32[1024]{0} sort(x), dimensions={0} + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare } )" @@ -1059,10 +1078,18 @@ ENTRY Sort { "SortKeyValue", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024]{0} parameter(0) values = s32[1024]{0} parameter(1) - ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1072,9 +1099,15 @@ ENTRY Sort { "SortKeyR2", R"(HloModule sort +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { x = f32[1024,16]{0,1} parameter(0) - ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0} + ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare } )" @@ -1084,10 +1117,18 @@ ENTRY Sort { "SortKeyValueR2", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024,16]{0,1} parameter(0) values = s32[1024,16]{0,1} parameter(1) - ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1097,12 +1138,42 @@ ENTRY Sort { "SortManyValues", R"(HloModule sort +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.2.lhs = u32[] parameter(4) + p.2.rhs = u32[] parameter(5) + p.3.lhs = f32[] parameter(6) + p.3.rhs = f32[] parameter(7) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024,16]{0,1} parameter(0) values.0 = s32[1024,16]{0,1} parameter(1) values.1 = u32[1024,16]{0,1} parameter(2) values.2 = f32[1024,16]{0,1} parameter(3) - ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare +} + +)" +}, +// Sort (Key) is_stable=true +{ +"SortKeyStable", +R"(HloModule sort + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare } )" @@ -1280,6 +1351,17 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } +)" +}, +// replica-id +{ +"ReplicaId", +R"(HloModule replica-id + +ENTRY Replica-id { + ROOT replica-id = u32[] replica-id() +} + )" }, // Iota @@ -1309,10 +1391,18 @@ ENTRY Computation { "ScheduledModule", R"(HloModule scheduled_module, is_scheduled=true +compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY Sort { keys = f32[1024]{0} parameter(0) values = s32[1024]{0} parameter(1) - ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare } )" @@ -1396,7 +1486,7 @@ class HloParameterizedParserTest protected: // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and - // checks that the it equals the original string. + // checks that it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -2515,6 +2605,60 @@ TEST_F(HloParserTest, ParseShapeStringWithLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) { + // One tile. + string shape_string = "f32[123,456]{0,1:T(2,128)}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); + Shape expected = + ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Tile with negative dimension size for combining dimensions. + shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = + ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2}, + {Tile({2, Tile::kCombineDimension, 128})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Two tiles. + shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout( + BF16, {123, 456, 789}, {2, 1, 0}, + {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})}); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Tile with element size in bits. + shape_string = "pred[123,456]{1,0:T(2,128)E(1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, + {Tile({2, 128})}, 1); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Element size in bits without tile. + shape_string = "pred[123,456]{1,0:E(1)}"; + TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string)); + expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1); + EXPECT_EQ(expected, actual) + << "expected: " << ShapeUtil::HumanStringWithLayout(expected) + << "actual: " << ShapeUtil::HumanStringWithLayout(actual); + + // Wrong minor_to_major. + shape_string = "f32[123,456,789]{1:T(2, * , 128)}"; + auto result = ParseShape(shape_string); + ExpectHasSubstr(result.status().error_message(), + "Dimensions size is 3, but minor to major size is 1."); +} + TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) { string shape_string = "f32[123,456]sparse{10}"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string)); @@ -2579,5 +2723,16 @@ TEST_F(HloParserTest, NegativeParameterNumber) { ::testing::HasSubstr("parameter number must be >= 0")); } +TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) { + const string hlo_string = + "par0 = (f32[3,5], f32[]) parameter(0), " + "parameter_replication={true,false,true}"; + auto result = ParseHloString(hlo_string); + ASSERT_FALSE(result.status().ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("parameter has 2 leaf buffers, but " + "parameter_replication has 3 elements")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 0fced7f15bdaf1dbe349e3b0fc6ada68393c6512..b7f507b1184dbe021effc1102a68040286480ed2 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -77,28 +77,51 @@ std::unique_ptr HloReachabilityMap::Build( const HloComputation* computation) { const auto& all = computation->MakeInstructionPostOrder(); auto result = absl::make_unique(all); - auto channel_dependency_map = computation->ComputeChannelDependencies(); + auto channel_group = computation->ComputeChannelDependencies(); - std::vector inputs; for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); + std::vector inputs; + const auto add_input = [&channel_group, &inputs](HloInstruction* input) { + inputs.push_back(input); + if (input->opcode() == HloOpcode::kAllReduce && input->all_reduce_id()) { + auto it = channel_group.find(*input->all_reduce_id()); + if (it != channel_group.end()) { + inputs.insert(inputs.end(), it->second.begin(), it->second.end()); + } + } + }; + + const auto add_dependencies = [&add_input](const HloInstruction* hlo) { + for (HloInstruction* operand : hlo->operands()) { + add_input(operand); + } + for (HloInstruction* predecessor : hlo->control_predecessors()) { + add_input(predecessor); + } + }; + + add_dependencies(hlo); switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(hlo->channel_id()); + if (it != channel_group.end()) { + for (HloInstruction* channel : it->second) { + if (channel->opcode() == HloOpcode::kSend) { + add_input(channel); + } + } } break; } case HloOpcode::kAllReduce: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); + auto it = channel_group.find(all_reduce_id.value()); + if (it != channel_group.end()) { + for (HloInstruction* all_reduce : it->second) { + add_dependencies(all_reduce); + } } } break; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index d7d66ae1c4592723ca991d5ee971fa72cc1af90a..5a5401e351384867016a3a9addfd43d57091848c 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -168,6 +168,35 @@ StatusOr HloRunner::Execute(std::unique_ptr module, /*profile=*/profile); } +StatusOr HloRunner::Execute( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*executable=*/executable.get(), + /*arguments=*/argument_buffers, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr HloRunner::Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return Execute( + /*module=*/std::move(executable), + /*arguments=*/argument_pointers, + /*profile=*/profile); +} + StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, const absl::Span arguments, bool run_hlo_passes, @@ -206,7 +235,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { // Get service run options. @@ -225,7 +254,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( } StatusOr HloRunner::ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile) { std::vector argument_pointers; diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index bb792cf8c9825ff67ca33bbcf2c3c32b1a0ecb85..098989cd4c78fb5ad57cd6700fbf99c50064f225 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -60,7 +60,7 @@ class HloRunner { // The number of times the infeed literal should be fed to the HLO module. // For a clean exit, this should match the iterations-per-loop parameter // used when generating the HLO module proto (that is usually the main - // while bounary counter). A value higher then iterations-per-loop would + // while boundary counter). A value higher then iterations-per-loop would // lead to infeed threads feeding to a gone computation, while a lower // value would trigger a stuck ExecuteReplicated() call (the computation // will be trying to infeed data which will never come). @@ -124,6 +124,14 @@ class HloRunner { bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + StatusOr Execute(std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + // As Execute(), but accepts and returns device buffers instead of host // buffers. StatusOr ExecuteWithDeviceBuffers( @@ -136,13 +144,16 @@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // In the following two calls, "executable" is not a unique_ptr to allow + // reuse of the Executable. This call may update the profile information in + // *executable. StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( - std::unique_ptr executable, + Executable* executable, const absl::Span arguments, ExecutionProfile* profile = nullptr); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 37cc146bd7a6f2aef9373bd4afd8572ffac6473c..f1d7e60f2b5a68408f6d428a0ec47fba3c9c4f12 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -96,13 +96,13 @@ string HloSharding::ToString() const { if (replicated_) { return "{replicated}"; - } else if (maximal_) { + } + if (maximal_) { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); - } else { - return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), - "]", StrJoin(tile_assignment_, ","), "}"); } + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", + StrJoin(tile_assignment_, ","), "}"); } bool HloSharding::UsesDevice(int64 device) const { @@ -328,8 +328,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, status = tensorflow::errors::InvalidArgument( StrCat("core ", core, " is not unique in tile assignment")); } + seen_cores.insert(core); } - seen_cores.insert(core); }); if (!status.ok()) { return status; @@ -347,7 +347,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } - // The correct constructor have to be used to create tile maximal shardings. + // The correct constructor has to be used to create tile maximal shardings. if (tile_assignment_.num_elements() == 1) { return tensorflow::errors::InvalidArgument( "Tile assignment only contains a single device. If a replicated " diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 5789ae09988d2a85247c5b8c037a172b3699f3b7..dd57ea83f1cb33aa052facb607bc040d2e708633 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -118,7 +118,7 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; - // Retrieves an histogram of the devices used by the sharding. The returned + // Retrieves a histogram of the devices used by the sharding. The returned // map has the device number as key, and the occurrence count as value. // If a sharding does not have a device, it will not be incuded in the // histogram. The count argument, if not nullptr, will receive the total @@ -260,6 +260,19 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; + // This field is only used if replicated_ is false. If maximal_ is true, then + // the field contains a rank 1 array with a single element, which is the + // device the HLO is assigned to. If maximal_ is false, the field contains an + // array with the same rank as the corresponding HLO. The dimension sizes of + // the array describe the number of ways the HLO is partitioned along each + // dimension. The values of the array specify which device each tile of + // the HLO is assigned to. The index of each value determines which tile it + // takes. + // For example, {{{2, 3}}, {{5, 7}}} (whose ToString representation is + // "{devices=[2,1,2]2,3,5,7}"), means that dimension 1 is split two way and + // dimension 3 is split 2 way. Core 5, whose index is [2,1,1] will take the + // tile that contains the 2nd half of dimension 1 and the 1st half of + // dimension 3. Array tile_assignment_; // Only non-empty when tuple_ is true. If a tuple is empty then one entry is // present for the root. This is a flattened list of all the leaf shardings in diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 80634677e78e4a35dcb9bf7de018a88122c3c030..9e234e025586ff14f99da73afc5610c627303a36 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -84,7 +84,7 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should fail because of more devices used then `num_device`. + // Test should fail because of more devices used than `num_device`. HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc deleted file mode 100644 index c1f69db74eafb7743e85f499f2f4828ed0375501..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -LIcensed under the Apache License, Version 2.0 (the "License"); -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using absl::StrAppend; -using absl::StrCat; -using tensorflow::GraphDef; -using tensorflow::NodeDef; -using tensorflow::TensorShapeProto; - -string GetOpDefName(const HloInstruction* instruction) { - string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok - name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); - - if (instruction->opcode() == HloOpcode::kFusion) { - string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, absl::string_view(fusion_name).substr(1)); - } - return name; -} - -TensorShapeProto GetTensorShape(const HloInstruction* instruction) { - TensorShapeProto tensor_shape; - const Shape& shape = instruction->shape(); - for (auto dim : shape.dimensions()) { - tensor_shape.add_dim()->set_size(dim); - } - return tensor_shape; -} - -string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } - -void CleanNodeName(string* name) { - name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); - const string chars_to_replace = "<>[]"; - auto pred = [&](char c) { - return absl::c_linear_search(chars_to_replace, c); - }; - std::replace_if(name->begin(), name->end(), pred, '_'); -} - -} // namespace - -HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) - : debug_options_(debug_options) {} - -Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { - VLOG(2) << "Adding computation " << computation.name(); - for (auto embedded : computation.MakeEmbeddedComputationsList()) { - for (auto* instruction : embedded->instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - } - for (auto* instruction : computation.instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction)); - } - return Status::OK(); -} - -const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } - -const string& HloTfGraphBuilder::GetNodeNameForInstruction( - const HloInstruction* instruction) { - if (ContainsKey(instruction_to_node_name_, instruction)) { - return instruction_to_node_name_[instruction]; - } - auto append = [](string* str, const string& other) { - if (str->empty()) { - *str = other; - } else if (!other.empty()) { - StrAppend(str, "/", other); - } - }; - string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes()) { - auto device = instruction->sharding_unique_device(); - if (device) { - node_name = StrCat("dev", *device); - } - } - // If an instruction is fused, put it in the subgraph of the fusion; - // otherwise, put it in the computation subgraph. - const HloComputation* computation = instruction->parent(); - if (computation->IsFusionComputation()) { - append(&node_name, - GetNodeNameForInstruction(computation->FusionInstruction())); - } else { - append(&node_name, computation->name()); - if (!instruction->metadata().op_name().empty()) { - // Always make computations contain TF ops but not the other way around. - append(&node_name, instruction->metadata().op_name()); - } - } - string instruction_name = instruction->name(); - if (instruction->opcode() == HloOpcode::kParameter) { - StrAppend(&instruction_name, ".", instruction->parameter_number()); - } - append(&node_name, instruction_name); - CleanNodeName(&node_name); - auto ret = - instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); - CHECK(ret.second); - return ret.first->second; -} - -void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, - NodeDef* node_def) const { - auto& attrs = *node_def->mutable_attr(); - - // Set the number of arguments for instructions that have variadic operands. - if (HloOpcodeIsVariadic(instruction->opcode())) { - tensorflow::AttrValue attr_value; - attr_value.set_i(instruction->operands().size()); - attrs["arg_num"] = attr_value; - } - - // Set the node type. - attrs["type"].set_s( - xla::PrimitiveType_Name(instruction->shape().element_type())); - - // Set the framework op (e.g. Tensorflow op) that generated this XLA op. - attrs["tf_op_type"].set_s(instruction->metadata().op_type()); - attrs["tf_op_name"].set_s(instruction->metadata().op_name()); - - // Set the shape of the output tensor. "_output_shapes" is a special attribute - // name used by Tensorboard for shapes of output tensors. - tensorflow::AttrValue shapes; - *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); - attrs["_output_shapes"] = shapes; - - // Set the layout. - if (LayoutUtil::HasLayout(instruction->shape())) { - string layout_string; - if (instruction->shape().IsTuple()) { - // For tuples, emit the full shape because the layout of a tuple is not - // represented in a single Layout field. - layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); - } else { - layout_string = StrCat( - "{", - absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), - "}"); - } - attrs["layout"].set_s(layout_string); - } - - // Set op-specific attributes. - switch (instruction->opcode()) { - case HloOpcode::kConcatenate: - case HloOpcode::kBroadcast: - case HloOpcode::kReduce: - case HloOpcode::kReverse: - case HloOpcode::kTranspose: - for (auto dim : instruction->dimensions()) { - attrs["dims"].mutable_list()->add_i(dim); - } - break; - case HloOpcode::kGetTupleElement: - attrs["index"].set_i(instruction->tuple_index()); - break; - case HloOpcode::kRng: - attrs["dist"].set_s( - RandomDistribution_Name(instruction->random_distribution())); - break; - case HloOpcode::kConstant: - if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s(instruction->literal().GetAsString({})); - } - break; - case HloOpcode::kCustomCall: - attrs["custom_call_target"].set_s(instruction->custom_call_target()); - break; - case HloOpcode::kSend: - case HloOpcode::kRecv: - attrs["channel_id"].set_i(instruction->channel_id()); - break; - default: - break; - } -} - -Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { - if (!visited_instructions_.insert(instruction).second) { - // Skip instructions that have already been added. - return Status::OK(); - } - - NodeDef* node_def = graph_def_.add_node(); - node_def->set_name(GetNodeNameForInstruction(instruction)); - node_def->set_op(GetOpDefName(instruction)); - - auto device = instruction->sharding_unique_device(); - if (device) { - node_def->set_device(GetDeviceName(*device)); - } - SetNodeAttrs(instruction, node_def); - if (instruction->opcode() == HloOpcode::kFusion) { - for (auto* fused_instruction : instruction->fused_instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); - } - } - // Add all edges including control edges. - for (unsigned i = 0; i < instruction->operands().size(); ++i) { - *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); - } - // Called computations are control dependencies. - for (const auto* called_computation : instruction->called_computations()) { - *node_def->add_input() = StrCat( - "^", GetNodeNameForInstruction(called_computation->root_instruction())); - } - return Status::OK(); -} - -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h deleted file mode 100644 index c4876b852e32d34693202f4023aa20ad2b301ffd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" - -namespace xla { -namespace hlo_graph_dumper { - -// This constructs a tensorflow graph for HLO computations. -class HloTfGraphBuilder { - public: - HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); - - // Adds a computation to the graph. - Status AddComputation(const HloComputation& computation); - - const tensorflow::GraphDef& GetGraphDef() const; - - private: - // Gets the node name of an instruction. The node name is hierarchical. For - // example, if an instruction is fused, it will be put in a subgraph of the - // fusion instruction. - const string& GetNodeNameForInstruction(const HloInstruction* instruction); - - void SetNodeAttrs(const HloInstruction* instruction, - tensorflow::NodeDef* node_def) const; - - Status AddInstruction(const HloInstruction* instruction); - - DebugOptions debug_options_; - tensorflow::GraphDef graph_def_; - // This records instructions that have been visited. - std::unordered_set visited_instructions_; - // A cache that maps instruction to the node name. - std::unordered_map instruction_to_node_name_; -}; - -} // namespace hlo_graph_dumper -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc deleted file mode 100644 index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace xla { -namespace hlo_graph_dumper { -namespace { - -using ::tensorflow::GraphDef; - -class HloTfGraphBuilderTest : public HloTestBase { - protected: - HloTfGraphBuilderTest() {} - HloTfGraphBuilder generator_; - - // Create a computation which takes a scalar and returns its negation. - std::unique_ptr CreateNegateComputation() { - auto builder = HloComputation::Builder("Negate"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - return builder.Build(); - } - - // Creates a computation which calls map with the given computation. - std::unique_ptr CreateMapComputation( - HloComputation *map_computation) { - auto builder = HloComputation::Builder("Map"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map_computation)); - return builder.Build(); - } - Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); -}; - -static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, - const string &attr_name) { - auto attr = node.attr().find(attr_name); - CHECK(attr != node.attr().end()); - return attr->second; -} - -TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { - auto builder = HloComputation::Builder("Concatenate"); - Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, shape, "param1")); - builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - const auto &node = graph_def.node(2); - EXPECT_EQ(node.name(), "Concatenate/concatenate"); - - // Check dimensions. - auto dims_value = GetNodeAttr(node, "dims"); - EXPECT_EQ(dims_value.list().i_size(), 1); - EXPECT_EQ(dims_value.list().i(0), 1); - - // Check shapes. - auto shape_value = GetNodeAttr(node, "_output_shapes"); - EXPECT_EQ(shape_value.list().shape_size(), 1); - EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); - EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); -} - -TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { - auto builder = HloComputation::Builder("Const"); - HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - OpMetadata metadata; - metadata.set_op_name("x"); - metadata.set_op_type("y"); - instruction->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 1); - const auto &node = graph_def.node(0); - EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); - EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); - EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); -} - -TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { - auto negate_computation = CreateNegateComputation(); - TF_CHECK_OK(generator_.AddComputation(*negate_computation)); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 2); - EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); - EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); - EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); - EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); - EXPECT_EQ(graph_def.node(1).input_size(), 1); - EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); -} - -TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { - auto builder = HloComputation::Builder("GE"); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32_, "param1")); - auto ge = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); - OpMetadata metadata; - metadata.set_op_name("x/y"); - metadata.set_op_type("Y"); - ge->set_metadata(metadata); - TF_CHECK_OK(generator_.AddComputation(*builder.Build())); - GraphDef graph_def = generator_.GetGraphDef(); - EXPECT_EQ(graph_def.node_size(), 3); - EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); - EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); - EXPECT_EQ(graph_def.node(2).input_size(), 2); - EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); - EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); -} - -TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { - // Create computations with a diamond-shaped callgraph. - auto negate_computation = CreateNegateComputation(); - auto map1_computation = CreateMapComputation(negate_computation.get()); - auto map2_computation = CreateMapComputation(negate_computation.get()); - - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32_, "param0")); - auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); - auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); - auto computation = builder.Build(); - TF_CHECK_OK(generator_.AddComputation(*computation)); - EXPECT_GT(generator_.GetGraphDef().node_size(), 0); -} - -} // namespace -} // namespace hlo_graph_dumper -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 4caaa5a32b1e213ff475591e32809f744bcb86ad..97e6ea9dad04238c2e6f1a49fba6d880ef3169c2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -50,6 +50,7 @@ bool IsCallerInstruction(HloInstruction* hlo) { case HloOpcode::kReduceWindow: case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSort: case HloOpcode::kFusion: return true; default: @@ -167,6 +168,15 @@ Status ShapeVerifier::HandleFft(HloInstruction* fft) { return CheckShape(fft, expected); } +Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 2)); + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferTriangularSolveShape( + hlo->operand(0)->shape(), hlo->operand(1)->shape(), + hlo->triangular_solve_options())); + return CheckShape(hlo, expected); +} + Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { std::vector operand_shapes; for (const HloInstruction* operand : crs->operands()) { @@ -327,13 +337,48 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError("Expected at least 1 operand for %s instruction: %s", HloOpcodeString(sort->opcode()), sort->ToString()); } + HloComputation* compare = sort->to_apply(); + + // Check that the 'compare' computation returns a PRED. + Shape compare_shape = compare->root_instruction()->shape(); + if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) { + return InternalError( + "The Sort compare computation shape does not lead to a scalar " + "predicate shape: %s", + StringifyShape(compare_shape)); + } + + // Check that the number of parameters of the 'compare' computation is + // correct. + TF_RETURN_IF_ERROR( + CheckParameterCount(sort, compare, sort->operand_count() * 2)); + + // Verify that the operands of the compare computation have the correct scalar + // shapes. + for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters(); + ++parameter_idx) { + int64 operand_idx = parameter_idx / 2; + Shape expected_scalar_shape = ShapeUtil::MakeShape( + sort->operand(operand_idx)->shape().element_type(), {}); + Shape actual_parameter_shape = + compare->parameter_instruction(parameter_idx)->shape(); + if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape, + actual_parameter_shape)) { + return InternalError( + "Expected the %lld-th parameter of the compare computation of sort " + "to have shape %s, but got %s", + parameter_idx, StringifyShape(expected_scalar_shape), + StringifyShape(actual_parameter_shape)); + } + } + + // Verify that all operand shapes have the same dimensions. for (int64 operand = 1; operand < sort->operand_count(); ++operand) { if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), sort->operand(operand)->shape())) { return InternalError( - "Expected sort to have to have the same dimensions for the keys " - "and the values. Keys shape is: %s\n, Values shape (operand index " - "%lld) is: %s", + "Expected sort to have to have the same dimensions for all operands. " + "First operand shape is: %s\n, shape (operand index %lld) is: %s", StringifyShape(sort->operand(0)->shape()), operand, StringifyShape(sort->operand(operand)->shape())); } @@ -376,6 +421,24 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { get_tuple_element->tuple_index())); } +namespace { +Status SameElementTypesForOperandsAndToApplyParameters( + const HloInstruction& instruction, int64 num_operands_to_check) { + const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape(); + for (int i = 0; i < num_operands_to_check; ++i) { + const Shape& parameter_shape = to_apply.parameters(i); + const Shape& operand_shape = instruction.operands()[i]->shape(); + if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) { + return InvalidArgument( + "Shape mismatch between to_apply computation" + " parameter and operand %d in %s.", + i, instruction.ToString().c_str()); + } + } + return Status::OK(); +} +} // namespace + Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { if (reduce->operand_count() % 2 != 0) { return InternalError( @@ -387,9 +450,15 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(reduce, ShapeInference::InferReduceShape( - operand_shapes, reduce->dimensions(), - reduce->to_apply()->ComputeProgramShape())); + TF_RETURN_IF_ERROR( + CheckShape(reduce, ShapeInference::InferReduceShape( + operand_shapes, reduce->dimensions(), + reduce->to_apply()->ComputeProgramShape()))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters( + *reduce, reduce->operands().size() - 1); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { @@ -545,19 +614,31 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { // arbitrary map dimensions. std::vector map_dims(max_operand_rank); std::iota(map_dims.begin(), map_dims.end(), 0); - return CheckShape(map, ShapeInference::InferMapShape( - operand_shapes, - map->to_apply()->ComputeProgramShape(), map_dims)); + + TF_RETURN_IF_ERROR(CheckShape( + map, + ShapeInference::InferMapShape( + operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters( + *map, map->operands().size()); } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); - return CheckShape( + TF_RETURN_IF_ERROR(CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( reduce_window->operand(0)->shape(), reduce_window->operand(1)->shape(), reduce_window->window(), - reduce_window->to_apply()->ComputeProgramShape())); + reduce_window->to_apply()->ComputeProgramShape()))); + + return allow_mixed_precision_ + ? Status::OK() + : SameElementTypesForOperandsAndToApplyParameters(*reduce_window, + 1); } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { @@ -1244,8 +1325,8 @@ Status CheckFusionInstruction(HloInstruction* fusion) { return Status::OK(); } -// Checks that the non-scalar operand shapes are compatible to the output -// shape, i.e., that there are no implicit broadcasts of size-one dimensions. +// Checks that the operand shapes are compatible to the output shape, i.e., +// that there are no implicit broadcasts. Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index facb76a124a4166e2a29c34f01194c9ebb62498b..a9b5e9a3e6eec19e125188a192694fcaadfe2322 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -52,6 +52,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; + Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 4f69bd155b8713041ba539098808125956e86259..523890b3c7268c06cdb6aaa67749f26a1cb62855 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -552,5 +552,67 @@ TEST_F(HloVerifierTest, IotaNonArrayResult) { HasSubstr("does not support non-array result")); } +static const char* const kMapOperandComputationMismatchHlo = R"( + HloModule MapOperandComputationMismatch + + Computation { + param0 = f32[] parameter(0) + constant = f32[] constant(1) + ROOT add = f32[] add(param0, constant) + } + + ENTRY kernelEntry { + param = f64[] parameter(0) + ROOT map = f32[] map(param), dimensions={}, to_apply=Computation +})"; + +TEST_F(HloVerifierTest, MapOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kMapOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr( + "Shape mismatch between to_apply computation parameter and operand")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kMapOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +static const char* const kReduceOperandComputationMismatchHlo = R"( + HloModule ReduceOperandComputationMismatch + computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY kernelEntry { + arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0) + constant = f16[] constant(0) + reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation + })"; + +TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kReduceOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to f32[64]")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kReduceOperandComputationMismatchHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 76bf48870d55e82497ba5f63e9e2e2a322cb330e..c5d32a4b9ad8c708ec0870173fa72320238e8464 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { -namespace gtl = ::tensorflow::gtl; namespace { using Analysis = IndexedArrayAnalysis; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index d4794acb2f463c4cf8ce5e969f221d52e3742453..a5767774f2dcebf8fd309823d48dcc3269b3f594 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -158,6 +158,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSort: case HloOpcode::kTanh: case HloOpcode::kTrace: + case HloOpcode::kTriangularSolve: case HloOpcode::kWhile: case HloOpcode::kGetDimensionSize: return true; diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index a305c6e8005045f7dbca3b8099a3b8ddebb092af..8cd936268994c2a25c2c0debe0a003d1d05cbd0b 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:triangular_solve_expander", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 0827b1daf89bebb68c045784ef2b9da677792880..792773c676984aa280c1b20cb7fd0fc7c9425f6c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -79,6 +80,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 9376a3c8f8963551a89dcedd77068a39ffd05301..2dd9e055503e01f5c1ace5e6cd3dc64012828b71 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2070,6 +2070,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: + case HloOpcode::kTriangularSolve: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: return false; diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index c66eaec8fb0e4c03f6967fec0cf0ae9661cdf470..3acceccfa556103c15fe229c41e96e618ac59c80 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -113,20 +113,10 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, Shape output_shape = output_array.GetShape(); Shape update_shape = update_array.GetShape(); - IndexGenerator start_indices_generator; - // TODO(b/118437727): Remove the R1 path, and rename the variables. - if (start_indices_array.GetShape().rank() == 1) { - start_indices_generator = [&](int64 index) { - return start_indices_array.EmitReadArrayElement( - IrArray::Index({b->getInt64(index)}), b); - }; - } else { - start_indices_generator = [&](int64 index) { - return operand_arrays[2 + index].EmitReadArrayElement( - IrArray::Index(b->getInt64Ty()), b); - }; - } - + IndexGenerator start_indices_generator = [&](int64 index) { + return operand_arrays[2 + index].EmitReadArrayElement( + IrArray::Index(b->getInt64Ty()), b); + }; ElementGenerator update_array_generator = [&](const IrArray::Index& index) { return update_array.EmitReadArrayElement(index, b); }; @@ -178,21 +168,11 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); - // TODO(b/118437727): Remove the R1 path, and rename the variables. - IndexGenerator start_indices_generator; - if (start_indices->shape().rank() == 1) { - start_indices_generator = [&](int64 index) { - return fused_emitter.GetGenerator(start_indices)( - IrArray::Index({b->getInt64(index)})); - }; - } else { - start_indices_generator = [&](int64 index) { - ElementGenerator element_generator = - fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); - return element_generator(IrArray::Index(b->getInt64Ty())); - }; - } - + IndexGenerator start_indices_generator = [&](int64 index) { + ElementGenerator element_generator = + fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index)); + return element_generator(IrArray::Index(b->getInt64Ty())); + }; bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( update_shape, start_indices_generator, is_signed, update_array_generator, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index cf5083e8c13b9485035923895cec1ad05049c644..02c719502ee7b0a732ae74acec364f89d51ae0c1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -269,6 +269,11 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpUNE(std::forward(args)...); } + template + llvm::Value* FCmpUNO(Args&&... args) { + return mixin_builder()->CreateFCmpUNO(std::forward(args)...); + } + template llvm::Value* FDiv(Args&&... args) { return mixin_builder()->CreateFDiv(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index fe320bbe727111fbc986cc1fbc217feed74d30f1..3a35405a2da0af386e01bb48bed56ad194048543 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 0dc120e0b0df47f261435f490a8459b49d989b53..a689881e65ec3a7ddf606c36bdd64b749cfe358e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 89b6a36f96beedbcb7322e6164ac59221650d3d8..d71addec9b7317dfe16e9d7e5380c3cfda0b8c06 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -45,13 +46,14 @@ namespace llvm_ir { namespace { // Adds the inner comparison loop body where we compare elements. -void EmitCompareLoopBody( - int64 iteration_bound, PrimitiveType key_type, int64 num_values, - int64 iota_values_parameter_index, llvm::Value* element_pair_index, +Status EmitCompareLoopBody( + int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, - std::function read_element, + std::function + element_address, std::function write_element, + const EmitCallToNestedComputationCallback& emit_compare_callback, llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { auto index_typed_constant = [&](int64 value) { return llvm::ConstantInt::get(index_type, value); @@ -108,74 +110,44 @@ void EmitCompareLoopBody( // if (is_smaller_index && index_is_inbounds) KernelSupportLibrary ksl(b); - ksl.If("smaller_comparison_index", do_comparison, [&]() { - auto key1 = read_element(0, current_keys_index); - auto key2 = read_element(0, compare_keys_index); - auto compare_key1 = key1; - auto compare_key2 = key2; - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - // We would like a total order of floating point numbers so that the - // sort has a predictable behavior in the presence of NaNs. Rather - // than using floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the - // obvious order, -0 is ordered before 0, and -NaN and NaN appear at - // the beginning and end of the ordering. - auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; - } - // If key2 < key1 - auto is_smaller_than = - b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - if (iota_values_parameter_index >= 0) { - auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); - auto key_index1 = - read_element(iota_values_parameter_index, current_keys_index); - auto key_index2 = - read_element(iota_values_parameter_index, compare_keys_index); - auto index_is_smaller_than = - b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); - is_smaller_than = b->CreateOr( - is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() { + std::vector values_to_compare; + for (int i = 0; i < num_values; ++i) { + values_to_compare.push_back(element_address(i, compare_keys_index)); + values_to_compare.push_back(element_address(i, current_keys_index)); } + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer", + b); + TF_RETURN_IF_ERROR( + emit_compare_callback(values_to_compare, compare_return_buffer)); + llvm::Value* result = b->CreateLoad(compare_return_buffer); + + // Check if the 'compare' function returns true. + llvm::Value* is_smaller_than = + b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0), + "boolean_predicate"); ksl.If("is_smaller_than", is_smaller_than, [&]() { - // Swap key1 with key2. - write_element(0, current_keys_index, key2); - write_element(0, compare_keys_index, key1); - for (int64 i = 1; i <= num_values; ++i) { - // Also swap the values. - auto value1 = read_element(i, current_keys_index); - auto value2 = read_element(i, compare_keys_index); - write_element(i, current_keys_index, value2); - write_element(i, compare_keys_index, value1); + for (int64 i = 0; i < num_values; ++i) { + // Swap the values. + auto value1 = b->CreateLoad(values_to_compare[i * 2]); + auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]); + write_element(i, current_keys_index, value1); + write_element(i, compare_keys_index, value2); } }); + return Status::OK(); }); } -void EmitTiledCompareLoop( +Status EmitTiledCompareLoop( const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, - int64 dimension_to_sort_bound, PrimitiveType keys_type, - absl::Span xor_masks, const std::vector& params, - const std::vector& param_shmem_buffers, - int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { + int64 dimension_to_sort_bound, absl::Span xor_masks, + const std::vector& params, + const std::vector& param_shmem_buffers, int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback, + llvm::IRBuilder<>* b) { KernelSupportLibrary ksl(b); llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); @@ -200,7 +172,7 @@ void EmitTiledCompareLoop( [&]() { auto cache_index = b->CreateShl(thread_id, value_one); read_or_write(cache_index, current_keys_index); - // Increment to go the next index position. + // Increment to go to the next index position. current_keys_index = b->CreateAdd(current_keys_index, value_one); // Here we check whether the next index position is within bounds. ksl.If("inner_smaller_keys_index", @@ -230,10 +202,18 @@ void EmitTiledCompareLoop( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); // Now emit the bodies of the comparison loops. - auto read_element = [&](int64 operand, llvm::Value* index) { - return b->CreateLoad( + auto element_address = [&](int64 operand, llvm::Value* index) { + auto shared_memory_address = b->CreateGEP(param_shmem_buffers[operand], - {tiled_keys_index.GetConstantWithIndexType(0), index})); + {tiled_keys_index.GetConstantWithIndexType(0), index}); + auto ptr_type = shared_memory_address->getType(); + // We need a generic pointer with address space 0 instead of a pointer to + // shared memory (address space 3) so that we can pass it to the comparison + // computation. + return b->CreateAddrSpaceCast( + shared_memory_address, + llvm::PointerType::get(ptr_type->getPointerElementType(), + /*AddressSpace=*/0)); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { @@ -252,7 +232,7 @@ void EmitTiledCompareLoop( if (dimension_to_sort_bound % tile_size) { // Otherwise we need a bounds check for the last tile. The last tile has // size 'dimension_to_sort_bound' % 'tile_size'. - ksl.If( + TF_RETURN_IF_ERROR(ksl.IfWithStatus( "is_last_tile", b->CreateICmpUGE( b->CreateMul(tiled_keys_index[dimension_to_sort], @@ -260,24 +240,24 @@ void EmitTiledCompareLoop( tiled_keys_index.GetConstantWithIndexType( RoundDownToNearest(dimension_to_sort_bound, tile_size))), [&]() { - EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, - params.size() - 1, iota_values_parameter_index, - element_pair_index, xor_mask, - tiled_keys_index.GetType(), read_element, - write_element, b); + return EmitCompareLoopBody( + dimension_to_sort_bound % tile_size, params.size(), + element_pair_index, xor_mask, tiled_keys_index.GetType(), + element_address, write_element, emit_compare_callback, b); }, [&]() { - EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - iota_values_parameter_index, element_pair_index, - xor_mask, tiled_keys_index.GetType(), - read_element, write_element, b, - /*needs_bounds_checks=*/false); - }); + return EmitCompareLoopBody( + tile_size, params.size(), element_pair_index, xor_mask, + tiled_keys_index.GetType(), element_address, write_element, + emit_compare_callback, b, + /*needs_bounds_checks=*/false); + })); } else { - EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, - iota_values_parameter_index, element_pair_index, - xor_mask, tiled_keys_index.GetType(), read_element, - write_element, b, /*needs_bounds_checks=*/false); + TF_RETURN_IF_ERROR(EmitCompareLoopBody( + tile_size, params.size(), element_pair_index, xor_mask, + tiled_keys_index.GetType(), element_address, write_element, + emit_compare_callback, b, + /*needs_bounds_checks=*/false)); } // Wait until all comparisons have happened. llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); @@ -301,17 +281,16 @@ void EmitTiledCompareLoop( // same location in shared memory because we have exactly tile_size / 2 many // threads, and the linear index calculated by ParallelLoopEmitter uses // linear_index = blockIdx.x * blockDim.x + threadIdx.x; + return Status::OK(); } } // namespace -Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const std::vector& values_arrays, - int64 iota_values_parameter_index, - absl::string_view name, - absl::Span xor_masks, llvm::IRBuilder<>* b, - const gpu::LaunchDimensions& launch_dimensions, - int64 num_iterations_in_sort_dim, - const int64 tile_size) { +Status EmitSortInPlace( + int64 dimension_to_sort, const std::vector& values_arrays, + absl::string_view name, absl::Span xor_masks, + llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, const int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback) { // Iterate through the keys shape in physical order, but skip the dimension to // sort and make it the innermost loop which is the loop where the comparisons // happen. In the dimension to sort, if we use tiling, we iterate through it @@ -321,7 +300,7 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, // within those 64 elements and are therefore independent of the other // comparisons). - const Shape& keys_shape = keys_array.GetShape(); + const Shape& keys_shape = values_arrays[0].GetShape(); int64 rank = keys_shape.rank(); int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); std::vector dimensions_in_iteration_order(rank); @@ -338,18 +317,16 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), dimensions_in_iteration_order); - std::vector params(1, keys_array); - params.insert(params.end(), values_arrays.begin(), values_arrays.end()); // Allocate shared memory for the tiled compare loop. - std::vector param_shmem_buffers(params.size(), nullptr); + std::vector param_shmem_buffers(values_arrays.size(), nullptr); if (xor_masks.size() > 1) { llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); - for (int64 i = 0; i < params.size(); ++i) { - llvm::Type* tile_type = - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - params[i].GetShape().element_type(), module), - tile_size); + for (int64 i = 0; i < values_arrays.size(); ++i) { + llvm::Type* tile_type = llvm::ArrayType::get( + llvm_ir::PrimitiveTypeToIrType( + values_arrays[i].GetShape().element_type(), module), + tile_size); param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( module, tile_type, absl::StrCat(name, "_tile_param_", i)); } @@ -376,25 +353,24 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; } if (xor_masks.size() > 1) { - EmitTiledCompareLoop(keys_index, dimension_to_sort, - dimension_to_sort_bound, keys_shape.element_type(), - xor_masks, params, param_shmem_buffers, - iota_values_parameter_index, tile_size, b); + TF_RETURN_IF_ERROR(EmitTiledCompareLoop( + keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks, + values_arrays, param_shmem_buffers, tile_size, emit_compare_callback, + b)); } else { - auto read_element = [&](int64 operand, llvm::Value* index) { + auto element_address = [&](int64 operand, llvm::Value* index) { keys_index[dimension_to_sort] = index; - return params[operand].EmitReadArrayElement(keys_index, b); + return values_arrays[operand].EmitArrayElementAddress(keys_index, b); }; auto write_element = [&](int64 operand, llvm::Value* index, llvm::Value* value) { keys_index[dimension_to_sort] = index; - params[operand].EmitWriteArrayElement(keys_index, value, b); + values_arrays[operand].EmitWriteArrayElement(keys_index, value, b); }; - EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), - values_arrays.size(), iota_values_parameter_index, - tiles_index[rank - 1], xor_masks[0], - tiles_index.GetType(), read_element, write_element, - b); + TF_RETURN_IF_ERROR(EmitCompareLoopBody( + dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1], + xor_masks[0], tiles_index.GetType(), element_address, write_element, + emit_compare_callback, b)); } return Status::OK(); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 685f9383acba416f51681270e4037d56abb4b6ea..b9341a34d1f2203db6e02c3df5d607174b6d0f74 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -28,19 +28,18 @@ limitations under the License. namespace xla { namespace llvm_ir { +using EmitCallToNestedComputationCallback = + std::function, llvm::Value*)>; // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' -// dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' -// contains only powers of 2, or values 2^k - 1 (k > 0). If -// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand -// that is a iota and can be used to make the sorting stable. -Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const std::vector& values_arrays, - int64 iota_values_parameter_index, - absl::string_view name, - absl::Span xor_masks, llvm::IRBuilder<>* b, - const gpu::LaunchDimensions& launch_dimensions, - int64 num_iterations_in_sort_dim, int64 tile_size); +// dimension of each array in 'values_arrays'. All other dimensions are kept +// as-is. This implements the inner loop of BitonicSort. It is assumed that +// 'xor_masks' contains only powers of 2, or values 2^k - 1 (k > 0). +Status EmitSortInPlace( + int64 dimension_to_sort, const std::vector& values_arrays, + absl::string_view name, absl::Span xor_masks, + llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size, + const EmitCallToNestedComputationCallback& emit_compare_callback); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/op_expander_pass.cc b/tensorflow/compiler/xla/service/op_expander_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..02c9d4b387b112be39c204d35fe4fa1013ed064c --- /dev/null +++ b/tensorflow/compiler/xla/service/op_expander_pass.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +StatusOr OpExpanderPass::Run(HloModule* module) { + std::vector matching_instructions; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + absl::c_copy_if( + computation->instructions(), std::back_inserter(matching_instructions), + [&](HloInstruction* inst) { return InstructionMatchesPattern(inst); }); + } + + for (HloInstruction* inst : matching_instructions) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, + ExpandInstruction(inst)); + if (expanded_root == nullptr) { + continue; + } + TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + } + + return !matching_instructions.empty(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/op_expander_pass.h b/tensorflow/compiler/xla/service/op_expander_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..276e3d70b8ecd8742e0b277698765063198fe872 --- /dev/null +++ b/tensorflow/compiler/xla/service/op_expander_pass.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass is an abstract superclass for passes that replace operations that +// match a pattern. It is intended to be subclassed, not used directly. +// +// This pass is useful for legalizing HLO instructions that a particular backend +// does not support into other HLO instructions. +class OpExpanderPass : public HloModulePass { + public: + StatusOr Run(HloModule* module) override; + + protected: + // Returns `true` if `instruction` should be expanded by this pass. + virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; + + // Returns a replacement for `instruction`, or nullptr if no replacement is + // neeeded (e.g. only the to_apply subcomputation of the instruction was + // modified). + virtual StatusOr ExpandInstruction( + HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OP_EXPANDER_PASS_H_ diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3ddcaae193ba266f35fa6f9922fe4f3a4970cdc5..3f4456c1bbf0f620609459256424b9cb30a04e13 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -534,6 +534,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * p.interior_padding(); + if (dimensions[i] < 0) { + return InvalidArgument("Padding result in negative size for dimension %d", + i); + } is_dynamic[i] = operand_shape.is_dynamic_dimension(i); } @@ -832,7 +836,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(larger_shape)); } if (small_is_dynamic != large_is_dynamic) { - if ((small_dimension_size == 1 && !small_is_dynamic) || + if (small_dimension_size == large_dimension_size || + (small_dimension_size == 1 && !small_is_dynamic) || (large_dimension_size == 1 && !large_is_dynamic)) { // Do nothing. It's OK when the size-1 dimension is not static. } else { @@ -1858,6 +1863,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } + if (ShapeUtil::IsZeroElementArray(in)) { + return in; + } Shape result = ShapeUtil::ChangeElementType(in, C64); result.set_dimensions(result.dimensions_size() - 1, fft_length[fft_rank - 1] / 2 + 1); @@ -1899,6 +1907,49 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, #undef RET_CHECK_RANK } +/* static */ StatusOr ShapeInference::InferTriangularSolveShape( + const Shape& a, const Shape& b, const TriangularSolveOptions& options) { + if (a.rank() < 2) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s", + a.ToString()); + } + if (b.rank() != a.rank()) { + return InvalidArgument( + "Arguments to triangular solve must have equal rank; got %s and %s.", + b.ToString(), a.ToString()); + } + if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + return InvalidArgument( + "The two minor dimensions of 'a' must have equal size, got %s.", + a.ToString()); + } + if (a.dimensions(a.rank() - 1) != + b.dimensions(b.rank() - (options.left_side() ? 2 : 1))) { + return InvalidArgument( + "The shared dimension of 'a' and 'b' does not match, got shapes %s and " + "%s", + a.ToString(), b.ToString()); + } + absl::Span a_batch_dims(a.dimensions()); + absl::Span b_batch_dims(b.dimensions()); + a_batch_dims.remove_suffix(2); + b_batch_dims.remove_suffix(2); + if (a_batch_dims != b_batch_dims) { + return InvalidArgument( + "The leading batch dimensions of the arguments to triangular solve " + "must be equal; got %s and %s.", + b.ToString(), a.ToString()); + } + if (!TriangularSolveOptions_Transpose_IsValid(options.transpose_a()) || + options.transpose_a() == TriangularSolveOptions::TRANSPOSE_INVALID) { + return InvalidArgument( + "Invalid transpose option value for triangular solve (%d).\n", + options.transpose_a()); + } + return b; +} + /* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { @@ -2345,8 +2396,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (operand_shape.rank() != number_of_indices) { return InvalidArgument( - "Dynamic update slice start number of dimensions %d must match rank " - "%d of slice input (%s).", + "Dynamic update slice start number of dimensions %d must match " + "rank %d of slice input (%s).", number_of_indices, operand_shape.rank(), ShapeUtil::HumanString(operand_shape)); } @@ -2433,7 +2484,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg)); } - if (index >= arg.tuple_shapes_size()) { + if (index < 0 || index >= arg.tuple_shapes_size()) { return InvalidArgument( "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", @@ -2582,7 +2633,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, operand_shape.dimensions(i) != 1) { return InvalidArgument( "Input dimension should be either 1 or equal to the output dimension " - "it's broadcasting into; the %lldth operand dimension is %lld, the " + "it is broadcasting into; the %lldth operand dimension is %lld, the " "%lldth output dimension is %lld.", i, operand_shape.dimensions(i), broadcast_dimensions[i], output_shape.dimensions(broadcast_dimensions[i])); @@ -2650,11 +2701,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector indices(operand.rank()); - std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != operand.rank() || - !std::is_permutation(dimensions.begin(), dimensions.end(), - indices.begin())) { + if (!IsPermutation(dimensions, operand.rank())) { return InvalidArgument( "Transpose dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 7d39ef38e05abf0a81683c1fb0f3999908b27d23..acb071ab18824472153fc608b812ad2d9c52651e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -116,6 +116,10 @@ class ShapeInference { static StatusOr InferFftShape(const Shape& in, FftType fft_type, absl::Span fft_length); + // Infers the shape produced by the given triangular solve operation. + static StatusOr InferTriangularSolveShape( + const Shape& a, const Shape& b, const TriangularSolveOptions& options); + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferAllReduceShape( diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 26120a06b823c9fddf378991cec434a880fb888d..f400ef51f07b006eef2ea674feff1dd72f836e77 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -252,7 +252,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const absl::Span& bcast) { + absl::Span bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -896,6 +896,20 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie())); } +TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_}); + auto inferredNegative_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, -1); + auto inferred2_status = + ShapeInference::InferGetTupleElementShape(tuple_shape, 2); + ASSERT_FALSE(inferredNegative_status.ok()); + ASSERT_FALSE(inferred2_status.ok()); + EXPECT_THAT(inferredNegative_status.status().error_message(), + HasSubstr("attempt to index out of tuple bounds")); + EXPECT_THAT(inferred2_status.status().error_message(), + HasSubstr("attempt to index out of tuple bounds")); +} + TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = ShapeInference::InferBinaryOpShape( @@ -1467,6 +1481,14 @@ TEST_F(ShapeInferenceTest, Pad) { Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); + + dimension1->set_edge_padding_low(-20); + dimension1->set_edge_padding_high(-10); + auto negative_dimension_size = ShapeInference::InferPadShape( + input_shape, padding_value_shape, padding_config); + ASSERT_FALSE(negative_dimension_size.ok()); + ASSERT_THAT(negative_dimension_size.status().error_message(), + HasSubstr("negative size for dimension 1")); } TEST_F(ShapeInferenceTest, Reverse) { @@ -1550,6 +1572,16 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Rank1Transpose) { + Shape a_shape = ShapeUtil::MakeShape(F32, {5}); + auto inferred_shape_and_status = + ShapeInference::InferTransposeShape(a_shape, {0}); + EXPECT_IS_OK(inferred_shape_and_status); + Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); + EXPECT_TRUE( + ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); +} + TEST_F(ShapeInferenceTest, Conditional) { auto inferred_status0 = ShapeInference::InferConditionalShape( pred_, vector_32_, vector_64_, diff --git a/tensorflow/compiler/xla/service/sort_simplifier.cc b/tensorflow/compiler/xla/service/sort_simplifier.cc index 4a00e8d7b227f14d462ca53f695189f3f48754ee..122366a0f322a66963b364e1b19629cbd2d9aabe 100644 --- a/tensorflow/compiler/xla/service/sort_simplifier.cc +++ b/tensorflow/compiler/xla/service/sort_simplifier.cc @@ -14,12 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/sort_simplifier.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" + +#include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace { @@ -39,8 +42,7 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { return false; } - // Index 0 is the sorting key used by the sort HLO itself. - absl::flat_hash_set used_indices{0}; + absl::flat_hash_set used_indices; for (const HloInstruction* user : sort->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { // Can't analyse users other then get-tuple-element. @@ -49,15 +51,25 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { used_indices.insert(user->tuple_index()); } + // Also note which parameters are used by the comparator computation. + auto comparator = sort->to_apply(); + for (int64 i = 0; i < sort->operand_count() * 2; ++i) { + if (comparator->parameter_instruction(i)->user_count() > 0) { + // operand i corresponds to parameters 2 * i and 2 * i + 1 of the + // computation. + used_indices.insert(i / 2); + } + } + if (used_indices.size() == sort->operand_count()) { // All operands are used. return false; } - std::vector operands{sort->mutable_operand(0)}; - std::vector new_shapes{sort->operand(0)->shape()}; - for (int64 i = 1; i < sort->operand_count(); ++i) { - if (used_indices.count(i)) { + std::vector operands; + std::vector new_shapes; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (used_indices.contains(i)) { operands.push_back(sort->mutable_operand(i)); new_shapes.push_back(sort->operand(i)->shape()); } @@ -68,6 +80,32 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { : ShapeUtil::MakeTupleShape(new_shapes); HloInstruction* new_sort = computation->AddInstruction( sort->CloneWithNewOperands(new_sort_shape, operands)); + absl::flat_hash_map> + replacements; + int64 parameter_number = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto* old_lhs_parameter = comparator->parameter_instruction(i * 2); + auto* old_rhs_parameter = comparator->parameter_instruction(i * 2 + 1); + if (used_indices.contains(i)) { + Shape scalar_shape = + ShapeUtil::MakeShape(sort->operand(i)->shape().element_type(), {}); + replacements[old_lhs_parameter] = HloInstruction::CreateParameter( + parameter_number, scalar_shape, + absl::StrCat("p.", parameter_number / 2, ".lhs")); + ++parameter_number; + replacements[old_rhs_parameter] = HloInstruction::CreateParameter( + parameter_number, scalar_shape, + absl::StrCat("p.", parameter_number / 2, ".rhs")); + ++parameter_number; + } else { + replacements[old_lhs_parameter] = nullptr; + replacements[old_rhs_parameter] = nullptr; + } + } + HloModule* module = sort->GetModule(); + HloComputation* new_compare = module->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements))); + new_sort->set_to_apply(new_compare); // Map from original get-tuple-element tuple index to new HLO instruction absl::flat_hash_map result_map; @@ -83,7 +121,8 @@ StatusOr RemoveUnusedOperandFromSort(HloInstruction* sort) { } } } else { - result_map[0] = new_sort; + CHECK_EQ(used_indices.size(), 1); + result_map[*used_indices.begin()] = new_sort; } std::vector users(sort->users().begin(), sort->users().end()); diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc index cd05fcf830d32e8bac4f8b260d3dd143ab98ad7b..696ac1b465848894f8dcb1c88bc48c6a5b268ef4 100644 --- a/tensorflow/compiler/xla/service/sort_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc @@ -34,13 +34,21 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { const char* hlo_string = R"( HloModule permutation_sort - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} parameter(1) - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), - dimensions={1} - ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 - })"; + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); @@ -58,17 +66,27 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) { const char* hlo_string = R"( HloModule permutation_sort - ENTRY sort_computation { - keys = f32[64,87] parameter(0) - values.0 = s32[64,87] parameter(1) - values.1 = u32[64,87] parameter(2) - sort = (f32[64,87], s32[64,87], u32[64,87]) sort( - keys, values.0, values.1), - dimensions={1} - gte.0 = f32[64,87] get-tuple-element(sort), index=0 - gte.1 = u32[64,87] get-tuple-element(sort), index=2 - ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) - })"; + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.2.lhs = u32[] parameter(4) + p.2.rhs = u32[] parameter(5) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,87] parameter(0) + values.0 = s32[64,87] parameter(1) + values.1 = u32[64,87] parameter(2) + sort = (f32[64,87], s32[64,87], u32[64,87]) sort( + keys, values.0, values.1), + dimensions={1}, to_apply=compare + gte.0 = f32[64,87] get-tuple-element(sort), index=0 + gte.1 = u32[64,87] get-tuple-element(sort), index=2 + ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1) + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); @@ -86,17 +104,57 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) { const char* hlo_string = R"( HloModule permutation_sort - ENTRY sort_computation { - keys = f32[64,8732]{1,0} parameter(0) - values = s32[64,8732]{1,0} parameter(1) - sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} - ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 - })"; + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}, to_apply=compare + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); SortSimplifier simplifier; EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } + +TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare + ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + SortSimplifier simplifier; + uint64 num_executions = 0; + do { + num_executions++; + } while (simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(num_executions, 2); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Sort(m::Parameter(1)))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.cc b/tensorflow/compiler/xla/service/stable_sort_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..1aa7e5fe7c0d57ee3303480e4727c456727f64c8 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.cc @@ -0,0 +1,204 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Looks for a iota operand that can be used as tie breaker in the computation. +// If no matching iota operand is found, a iota operand is added to Sort. The +// comparison computation is adjusted to break ties using the values from the +// iota operand. +StatusOr StableSortExpander::ExpandInstruction( + HloInstruction* instruction) { + auto* sort = Cast(instruction); + HloComputation* computation = sort->parent(); + + HloInstruction* expanded_sort = nullptr; + absl::flat_hash_set used_indices; + int64 iota_index = -1; + for (const HloInstruction* operand : sort->operands()) { + // We can only use the iota operand if it has an iota dimension which is the + // same as the dimension to sort. Also it should have an integral type that + // is large enough for the number of elements in the sort dimension. For + // now, we only allow S32, because we expect to find a S32 iota operand for + // all Sort ops which are created by TopK. + // TODO(b/122298745): Also support other types. + if (operand->opcode() == HloOpcode::kIota && + Cast(operand)->iota_dimension() == + sort->sort_dimension() && + operand->shape().element_type() == S32) { + iota_index = sort->operand_index(operand); + break; + } + } + + // If there is currently no iota operand which we could use for making the + // sort stable, we will have to add a new such operand. + if (iota_index == -1) { + Shape iota_shape = sort->operand(0)->shape(); + // We might need to use S64 if the number of elements in the sort dimension + // is bigger than 2^31 - 1. + // TODO(b/122298745): Handle Sort ops where S32 is too small for the number + // of elements in the sort dimension. + if (iota_shape.dimensions(sort->sort_dimension()) > + std::numeric_limits::max()) { + return Unimplemented( + "Stable sorting of more than 2^31-1 elements is not implemented"); + } + iota_shape.set_element_type(S32); + auto iota = computation->AddInstruction( + HloInstruction::CreateIota(iota_shape, sort->sort_dimension())); + + // Create a new comparator. + auto comparator = sort->to_apply(); + absl::flat_hash_map> + replacements; + std::vector> extra_parameters; + std::vector extra_parameter_ptrs; + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".lhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + extra_parameters.push_back(HloInstruction::CreateParameter( + sort->operand_count() * 2 + 1, scalar_shape, + absl::StrCat("p.", sort->operand_count(), ".rhs"))); + extra_parameter_ptrs.push_back(extra_parameters.back().get()); + sort->set_to_apply(sort->GetModule()->AddEmbeddedComputation( + comparator->CloneWithReplacements(std::move(replacements), + extra_parameter_ptrs))); + + // Replace the original sort op. + std::vector new_operands(sort->operands().begin(), + sort->operands().end()); + new_operands.push_back(iota); + std::vector new_shapes = sort->operand_count() == 1 + ? std::vector{sort->shape()} + : sort->shape().tuple_shapes(); + new_shapes.push_back(iota_shape); + Shape new_sort_shape = ShapeUtil::MakeTupleShape(new_shapes); + HloInstruction* new_sort = computation->AddInstruction( + sort->CloneWithNewOperands(new_sort_shape, new_operands)); + + // Add a "wrapper" around the new sort op to make sure we have the same + // shape as before. For the rank 1 case, we only need a GetTupleElement, + // otherwise we create a Tuple consisting of GetTupleElements of the new + // sort. + std::vector tuple_elements; + tuple_elements.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + sort->operand(i)->shape(), new_sort, i))); + } + expanded_sort = tuple_elements[0]; + if (tuple_elements.size() > 1) { + expanded_sort = computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + } + sort = Cast(new_sort); + iota_index = sort->operand_count() - 1; + } + + // Modify the computation to break ties using the iota operand. + auto comparator = sort->to_apply(); + std::vector instructions_postorder = + comparator->MakeInstructionPostOrder(); + absl::flat_hash_map replacements; + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. + auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second; + }; + HloInstruction* old_root = comparator->root_instruction(); + // The comparison computation gets 2 * n parameters (n being the number of + // operands of Sort), where parameters 2 * i and 2 * i + 1 correspond to two + // different scalars of operand i of Sort which are to be compared. The + // comparison computation should induce a strict weak order, so if + // to_apply(p1.lhs, p1.rhs, ..., pn.lhs, pn.rhs) is equal to + // to_apply(p1.rhs, p1.lhs, ..., pn.rhs, pn.lhs), we can conclude that the + // values to be compared are equivalent, and perform a tie-breaker comparison. + // + // We clone each instruction with at least one operand, but use as new + // operands of the instruction the replacements of the original operands. + // Parameter 2 * i is replaced by parameter 2 * i + 1 and vice versa. This + // should make sure that the cloned root instruction gives the result of the + // comparison computation when being called with each scalar pair reversed. + // parameters corresponding to the iota operand. + for (int64 i = 0; i < comparator->num_parameters(); ++i) { + replacements[comparator->parameter_instruction(i)] = + comparator->parameter_instruction(i ^ 1); + } + HloInstruction* cloned_root = nullptr; + for (HloInstruction* inst : instructions_postorder) { + if (inst->operand_count() == 0) { + continue; + } + std::vector new_operands; + new_operands.reserve(inst->operand_count()); + for (HloInstruction* operand : inst->operands()) { + new_operands.push_back(replace(operand)); + } + auto new_instruction = + inst->CloneWithNewOperands(inst->shape(), new_operands); + replacements[inst] = new_instruction.get(); + if (inst == old_root) { + cloned_root = new_instruction.get(); + } + comparator->AddInstruction(std::move(new_instruction)); + } + CHECK_NE(cloned_root, nullptr); + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + HloInstruction* same = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kEq, old_root, cloned_root)); + HloInstruction* tie_breaker = + comparator->AddInstruction(HloInstruction::CreateBinary( + scalar_pred, HloOpcode::kLt, + comparator->parameter_instruction(2 * iota_index), + comparator->parameter_instruction(2 * iota_index + 1))); + HloInstruction* new_root = + comparator->AddInstruction(HloInstruction::CreateTernary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker, + old_root)); + comparator->set_root_instruction(new_root); + + return expanded_sort; +} + +bool StableSortExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSort && + Cast(instruction)->is_stable(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stable_sort_expander.h b/tensorflow/compiler/xla/service/stable_sort_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..31b6fd92d25370218017c58072f1aa5e64df00c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which expands Sort ops that have the is_stable field set to true +// into equivalent Sort ops which guarantee stable sorting without relying on +// the is_stable field. +class StableSortExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "stable-sort-expander"; } + + private: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/stable_sort_expander_test.cc b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a62d953e6e8fa2f3c1ecfd9e4a7900eee74f9dca --- /dev/null +++ b/tensorflow/compiler/xla/service/stable_sort_expander_test.cc @@ -0,0 +1,358 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace m = match; + +using StableSortExpanderTest = HloTestBase; + +// Checks whether 'a' and 'b' are roots of equivalent computations, except that +// parameters 2 * i and 2 * i + 1 are switched. +bool IsSameComputationExceptParams(const HloInstruction* a, + const HloInstruction* b) { + if (a->opcode() != b->opcode() || a->operand_count() != b->operand_count()) { + return false; + } + if (a->opcode() == HloOpcode::kParameter) { + // Check that parameters were switched. + return a->parameter_number() == (b->parameter_number() ^ 1); + } + // If the operation has no operands, it should actually be the same. + if (a->operand_count() == 0) { + return a == b; + } + // Otherwise recursively compare all operands. + for (int64 i = 0; i < a->operand_count(); ++i) { + if (!IsSameComputationExceptParams(a->operand(i), b->operand(i))) { + return false; + } + } + return true; +} + +// Check that the comparison computation has been modified to add a tie breaker +// using 'iota_parameter'. +void CheckComputationHasTieBreaker(const HloInstruction* root, + int64 iota_parameter) { + // With the tie breaker, the root instruction should be + // Select(Eq(Comp(), CompReverse()), Lt(), Comp()) + // with Comp() being the original comparison function, and CompReverse() being + // the copied comparison function where the parameters are reversed. Lt() is + // the tie breaker comparison using the Iota operand. + ASSERT_EQ(root->opcode(), HloOpcode::kSelect); + ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq); + + // Check that the tie breaker instruction is correct. + EXPECT_THAT(root->operand(1), + GmockMatch(m::Lt(m::Parameter(iota_parameter * 2), + m::Parameter(iota_parameter * 2 + 1)))); + EXPECT_EQ(root->operand(2), root->operand(0)->operand(0)); + + // Check that Comp() and CompReverse() are equivalent except that + // CompReverse() has reversed parameters. + EXPECT_TRUE(IsSameComputationExceptParams(root->operand(0)->operand(0), + root->operand(0)->operand(1))); +} + +TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortReuseIotaOperandComplicatedComparison) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + max = u32[] constant(2147483647) + zero = s32[] constant(0) + lhs.signed = s32[] bitcast-convert(p.0.lhs) + lhs.unsigned = u32[] bitcast-convert(p.0.lhs) + lhs.flipped = u32[] subtract(max, lhs.unsigned) + lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped) + lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero) + lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed) + rhs.signed = s32[] bitcast-convert(p.0.rhs) + rhs.unsigned = u32[] bitcast-convert(p.0.rhs) + rhs.flipped = u32[] subtract(max, rhs.unsigned) + rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped) + rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero) + rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed) + ROOT lt = pred[] less-than(lhs.converted, rhs.converted) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} parameter(1) + ROOT sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Tuple( + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 0), + m::GetTupleElement( + m::Sort(m::Parameter(0), m::Parameter(1), m::Iota()), 1)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, HonorIsStableFlag) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=false + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_FALSE(stabilizer.Run(module.get()).ValueOrDie()); +} + +TEST_F(StableSortExpanderTest, + StabilizeSortDontReuseIotaOperandWrongDimension) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=0 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) + } + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), + dimensions={1}, to_apply=compare, is_stable=true + ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + // Simplify away the "wrapper" tuple around the new sort. + AlgebraicSimplifier simplifier(AlgebraicSimplifierOptions( + [](const Shape&, const Shape&) { return false; })); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota(), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/2); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + ROOT sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0))); + CheckComputationHasTieBreaker( + root->operand(0)->to_apply()->root_instruction(), /*iota_parameter=*/1); +} + +TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) { + const char* hlo_string = R"( + HloModule permutation_sort + + compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + mask = s32[] constant(65535) + lhs = s32[] and(p.0.lhs, mask) + rhs = s32[] and(p.0.rhs, mask) + ROOT lt = pred[] less-than(lhs, rhs) + } + + ENTRY sort_computation { + keys = s32[64,8732]{1,0} parameter(0) + sort = s32[64,8732]{1,0} sort(keys), dimensions={0}, to_apply=compare, + is_stable=true + ROOT neg = s32[64,8732]{1,0} negate(sort) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + StableSortExpander stabilizer; + EXPECT_TRUE(stabilizer.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Negate(m::GetTupleElement( + m::Sort(m::Parameter(0), m::Iota()), 0)))); + CheckComputationHasTieBreaker( + root->operand(0)->operand(0)->to_apply()->root_instruction(), + /*iota_parameter=*/1); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc similarity index 82% rename from tensorflow/compiler/xla/client/lib/triangular_solve.cc rename to tensorflow/compiler/xla/service/triangular_solve_expander.cc index ba7fde118fde990fbb4aa9a34dd0f0e67ff5a93b..b26cdc1db59b30d82b9ac58a8a2ac762220086be 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include #include @@ -33,6 +33,8 @@ limitations under the License. namespace xla { +namespace { + // Get the diagonal blocks of the coefficient matrix XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { XlaBuilder* builder = a.builder(); @@ -345,9 +347,10 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, }); } -XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool transpose_a, bool conjugate_a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + bool unit_diagonal, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -406,6 +409,20 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, return b; } + // TODO(phawkins): consider pushing triangle masking into + // InvertDiagonalBlocks. + if (unit_diagonal) { + // Mask everything but the subdiagonal/superdiagonal elements. + a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a)) + : Select(TriangleMask(a, 0), ZerosLike(a), a); + int64 k = ShapeUtil::GetDimension(a_shape, -1); + a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k), + /*broadcast_dimensions=*/{ndims - 2, ndims - 1}); + } else { + // Mask off the ignored elements of the triangular matrix a. + a = Triangle(a, lower); + } + // We find the diagonal blocks of the coefficient matrix auto diag_blocks = DiagonalBlocks(a, block_size); @@ -413,11 +430,6 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a, precision); - // Mask off the ignored elements of the triangular matrix a. - // TODO(phawkins): it would probably be preferable to perform this masking - // block by block inside SolveWithInvertedDiagonalBlocks. - a = Triangle(a, lower); - // We now find the solution using GEMMs auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, @@ -427,4 +439,66 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } +} // namespace + +bool TriangularSolveExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kTriangularSolve; +} + +StatusOr TriangularSolveExpander::ExpandInstruction( + HloInstruction* instruction) { + const TriangularSolveOptions& options = + instruction->triangular_solve_options(); + const string name = absl::StrFormat( + "xla.triangular_solve_%s_%s_%s_%s_%s_%s", + instruction->operand(0)->shape().ToString(), + instruction->operand(1)->shape().ToString(), + options.left_side() ? "left" : "right", + options.lower() ? "lower" : "upper", + TriangularSolveOptions_Transpose_Name(options.transpose_a()), + options.unit_diagonal() ? "unit" : "nonunit"); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // We do something unusual here: we build the computation using the + // XlaBuilder API, which is nominally an XLA client API. We do this because + // the external APIs for building complicated computations (XlaBuilder) + // are much more ergonomic than the internal ones. As it turns out, + // XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. + XlaBuilder builder(name); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b"); + bool transpose_a = + options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE; + bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT; + + BuildTriangularSolve(a, b, options.left_side(), options.lower(), + transpose_a, conjugate_a, options.unit_diagonal(), + /*block_size=*/128, + /*precision=*/PrecisionConfig::HIGHEST); + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..be2374ef8c86254d8db5ac1acac385aa0de7d3a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class TriangularSolveExpander : public OpExpanderPass { + public: + absl::string_view name() const override { + return "triangular_solve_expander"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 5e505aaf02f157d0cba9dff42b1a9b89a6691504..cc82e9bb0287b5a586fb21fee35d3124a6d6f121 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -699,6 +699,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. // (5) The 'user' of 'operand' is Sort, and it is the only user. +// (6) The 'user' of 'operand' is TriangularSolve, it is the second operand, +// and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -779,6 +781,14 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; } + if (user->opcode() == HloOpcode::kTriangularSolve) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 1; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index fd5759e44230db8223822d6ae0f511027f73d8f9..6f61fc44166298e86a88dfc4f0ce8526d65ffd02 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" @@ -1065,14 +1066,17 @@ TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto sort = - builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, MakeSortHlo(keys_shape, {keys}, 0, /*is_stable=*/false, + &builder, module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); EXPECT_TRUE( points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); @@ -1080,6 +1084,7 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto builder = HloComputation::Builder(TestName()); + module_ = CreateNewVerifiedModule(); Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); Shape values_shape = ShapeUtil::MakeShape(F32, {8}); @@ -1087,11 +1092,14 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { HloInstruction::CreateParameter(0, keys_shape, "keys")); auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); - auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, - {values})); + TF_ASSERT_OK_AND_ASSIGN( + auto* sort, + MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}), + {keys, values}, 0, /*is_stable=*/false, &builder, + module_.get())); - BuildModuleAndRunAnalysis(builder.Build()); + computation_ = module_->AddEntryComputation(builder.Build()); + RunAnalysis(); // The buffer for the keys can be shared with the first tuple entry. EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc index a36d3547a0987422c2658b0f3046f7b1f83369c6..94854047e530babe2234381a615aeb805f0d5933 100644 --- a/tensorflow/compiler/xla/shape.cc +++ b/tensorflow/compiler/xla/shape.cc @@ -34,8 +34,12 @@ Shape::Shape(const ShapeProto& shape_proto) { // instead of a constructor. if (shape_proto.dimensions_size() != shape_proto.is_dynamic_dimension_size()) { - LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " - "fields does not match number of dimension fields"; + if (shape_proto.is_dynamic_dimension_size() != 0) { + LOG(ERROR) << "Malformed shape proto: number of is_dynamic_dimension " + "fields does not match number of dimension fields"; + } else { + LOG(WARNING) << "Malformed shape proto: is_dynamic_dimension is empty"; + } } int64 num_dynamic_dimension_fields = std::min( shape_proto.dimensions_size(), shape_proto.is_dynamic_dimension_size()); @@ -143,26 +147,15 @@ bool Shape::Equal::operator()(const Shape& lhs, const Shape& rhs) { return false; } if (LayoutUtil::IsDenseArray(lhs)) { - if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - - const auto& lhs_tiles = lhs.layout().tiles(); - const auto& rhs_tiles = rhs.layout().tiles(); - if (lhs_tiles.size() != rhs_tiles.size()) { - return false; + Layout::Equal equal; + if (ignore_tiles_in_layout_) { + equal.IgnoreTiles(); } - for (int64 i = 0; i < lhs_tiles.size(); i++) { - if (!absl::c_equal(lhs_tiles[i].dimensions(), - rhs_tiles[i].dimensions())) { - return false; - } + if (ignore_element_size_in_layout_) { + equal.IgnoreElementSize(); } - - if (lhs.layout().element_size_in_bits() != - rhs.layout().element_size_in_bits()) { + if (!equal(lhs.layout(), rhs.layout())) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } } diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 1d594904e0b9e6f1779674e75b41b7a597788bac..78cea83c6d71e5965f10cd3a917ffccabd630462 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -146,10 +146,10 @@ class Shape { // // Examples: // - // - Comparing two shapes ignoring they layout difference: + // - Comparing two shapes ignoring their layout difference: // Equal().IgnoreLayout()(shape1, shape2); // - // - Comparing two shapes ignoring they layout and element type difference: + // - Comparing two shapes ignoring their layout and element type difference: // Equal().IgnoreLayout().IgnoreElementType()(shape1, shape2); class Equal { public: @@ -161,6 +161,14 @@ class Shape { ignore_layout_ = true; return *this; } + Equal& IgnoreTilesInLayout() { + ignore_tiles_in_layout_ = true; + return *this; + } + Equal& IgnoreElementSizeInLayout() { + ignore_element_size_in_layout_ = true; + return *this; + } Equal& IgnoreElementType() { ignore_element_type_ = true; return *this; @@ -174,8 +182,10 @@ class Shape { return *this; } - public: + private: bool ignore_layout_ = false; + bool ignore_tiles_in_layout_ = false; + bool ignore_element_size_in_layout_ = false; bool ignore_element_type_ = false; bool ignore_fp_precision_ = false; bool ignore_dynamic_dimension_ = false; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 1ada4bc0362f86bc770d4adfcd4d4b0ff7379c77..d045fc7a9e291258640eca75166e116cf7390a7b 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -89,7 +90,8 @@ namespace { // its Layout. StatusOr MakeShapeWithLayoutInternal( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major) { + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); @@ -100,11 +102,8 @@ StatusOr MakeShapeWithLayoutInternal( } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::MakeValidatedShape(element_type, dimensions)); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->clear(); - for (int64 value : minor_to_major) { - min2maj->push_back(value); - } + *shape.mutable_layout() = + LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits); if (!shape.has_layout()) { return InvalidArgument("Shape has no layout."); } @@ -189,8 +188,10 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major) { - return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + absl::Span minor_to_major, absl::Span tiles, + int64 element_size_in_bits) { + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major, + tiles, element_size_in_bits) .ValueOrDie(); } @@ -1256,6 +1257,43 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& input_shape, const Shape& output_shape) { CHECK(input_shape.IsArray()); CHECK(output_shape.IsArray()); + // Removing trivial dimensions from the shape simplifies the alignment + // algorithm since ones can go in any position. + if (HasDegenerateDimensions(input_shape) || + HasDegenerateDimensions(output_shape)) { + auto simple_output_shape = + AlignLayouts(DropDegenerateDimensions(input_shape), + DropDegenerateDimensions(output_shape)); + if (!simple_output_shape) { + return absl::nullopt; + } + + auto layout = simple_output_shape->layout().minor_to_major(); + // For each one sized dimension in the output, increment the dimension + // numbers in layout that are more minor than the one. + absl::InlinedVector dim_map; + dim_map.reserve(simple_output_shape->rank()); + for (int64 i = 0; i < output_shape.rank(); ++i) { + if (output_shape.dimensions(i) != 1) { + dim_map.push_back(i); + } + } + for (int64& d : layout) { + d = dim_map[d]; + } + + // Add the ones in descending order to the layout. Descending layouts tend + // to reduce the number of copies inserted in layout assignment. + for (int64 i = output_shape.rank() - 1; i >= 0; --i) { + if (output_shape.dimensions(i) == 1) { + layout.push_back(i); + } + } + Shape output_shape_with_layout = output_shape; + *output_shape_with_layout.mutable_layout()->mutable_minor_to_major() = + layout; + return output_shape_with_layout; + } int64 input_rank = input_shape.rank(); int64 output_rank = output_shape.rank(); @@ -1304,10 +1342,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (input_dimension_product != output_dimension_product) { return absl::nullopt; } + // We also need to store an end element so that we know where the last // alignment part ends. alignment.push_back({input_rank, output_rank}); - // Now check if the physical layout can potentially be aligned to the output // shape by changing the physical layout of the output shape. We need to check // that all dimension numbers that belong to the same alignment part appear @@ -1319,40 +1357,23 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, for (int64 i = 0; i < input_rank;) { int64 current_dimension_number = input_dimension_numbers[i]; - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(current_dimension_number) == 1) { - ++i; - continue; - } - - // Calculate the number of non-trivial dimension bounds in the input shape - // belonging to the current alignment part. + // Trivial dimensions are stripped. + CHECK_NE(input_shape.dimensions(current_dimension_number), 1); const int64 current_alignment_index = dimension_to_alignment_index[current_dimension_number]; // Because of the special end element that we added, we can be sure that // 'current_alignment_index' is < alignment.size() - 1. CHECK_LT(current_alignment_index, alignment.size() - 1); - int64 num_non_trivial_dimensions_in_alignment_part = 0; - for (int64 j = alignment[current_alignment_index].first; - j < alignment[current_alignment_index + 1].first; ++j) { - if (input_shape.dimensions(j) != 1) { - ++num_non_trivial_dimensions_in_alignment_part; - } - } // Check that the following 'num_non_trivial_dimensions_in_alignment_part' // dimension numbers (ignoring dimension numbers with dimension bound 1) are // in descending order and belong to the current alignment part. - for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + for (int64 j = 0; j < alignment[current_alignment_index + 1].first - + alignment[current_alignment_index].first; ++i, ++j) { if (i == input_rank) { return absl::nullopt; } - // Skip trivial dimensions with a bound of 1. - if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { - --j; - continue; - } // If the current dimension number belongs to a different alignment part, // or the dimension numbers are not in descending order, we can return // early. @@ -1363,22 +1384,11 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } current_dimension_number = input_dimension_numbers[i]; } - // The output dimension numbers that belong to the current alignment part - // need to appear in the same descending order as in the input. Again, we - // can skip dimensions with a bound of 1. + // need to appear in the same descending order as in the input. for (int64 j = alignment[current_alignment_index + 1].second - 1; j >= alignment[current_alignment_index].second; --j) { - if (output_shape.dimensions(j) != 1) { - output_layout.push_back(j); - } - } - } - // Now add all the dimensions with dimension bound 1 at the end of - // 'output_layout'. - for (int64 i = 0; i < output_rank; ++i) { - if (output_shape.dimensions(i) == 1) { - output_layout.push_back(i); + output_layout.push_back(j); } } CHECK_EQ(output_layout.size(), output_rank); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index fb6da7460e2475732d6f02758e5519fbdb7c0f8d..7f610a6085d6fbe3d3143d5027cdc43d4b07bcbf 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -398,7 +398,9 @@ class ShapeUtil { // Returns a value shape such that shape.has_layout(). static Shape MakeShapeWithLayout(PrimitiveType element_type, absl::Span dimensions, - absl::Span minor_to_major); + absl::Span minor_to_major, + absl::Span tiles = {}, + int64 element_size_in_bits = 0); static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, absl::Span dimensions, diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 126ae58293d12182e9b6e30f779f681829729526..020b062f6b1b032bab958772d3a6a1e35daee38b 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -761,8 +761,15 @@ TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { auto aligned_shape = ShapeUtil::AlignLayouts( input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); EXPECT_TRUE(aligned_shape); - EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), - ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithAllTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 1, 1, 1}, {0, 1, 3, 2}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 1, 1, 1, 1})); + EXPECT_TRUE(aligned_shape); EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); } diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index b88fe367d7416a26c1147fd5e10fb20772814fe5..aa7238f07d432aabb44d2cbed66786217e6a846c 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -25,6 +25,13 @@ limitations under the License. namespace xla { namespace status_macros { +ABSL_CONST_INIT const char kPossibleAutoJitAlternative[] = + "This error might be occurring with the use of xla.compile. If it is not " + "necessary that every Op be compiled with XLA, an alternative is to use " + "auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment " + "variable TF_XLA_FLAGS=\"tf_xla_auto_jit=2\" which will attempt to use xla " + "to compile as much of the graph as the compiler is able to."; + static Status MakeStatus(tensorflow::error::Code code, const string& message) { return Status(code, message); } diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index e51dd64e2a3dc7c359918cb08c6c94b2b4d9e91b..315136acc71670fa3ad48da4dc064e384ddadaa9 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -30,6 +30,10 @@ limitations under the License. namespace xla { namespace status_macros { +// This is a useful error message when encountering XLA Compiler errors that +// could be handled with the non-strict AutoJit mode. +extern const char kPossibleAutoJitAlternative[]; + // Stream object used to collect error messages in MAKE_ERROR macros // or append error messages with APPEND_ERROR. It accepts any // arguments with operator<< to build an error string, and then has an diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e8e779fb2a3f201ae056e6385eacfe6a63503749..562854756628df64fbf92d40af859f8b218b0cc2 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -317,6 +317,11 @@ xla_test( name = "conv_depthwise_backprop_filter_test", timeout = "long", srcs = ["conv_depthwise_backprop_filter_test.cc"], + # these backends do not natively handle batch group counts. + blacklisted_backends = [ + "gpu", + "cpu", + ], shard_count = 6, deps = [ "//tensorflow/compiler/xla:execution_options_util", @@ -676,7 +681,7 @@ xla_test( tags = [ "optonly", # This is a big test that we skip for capacity reasons in OSS testing. - "nooss", + "no_oss", ], deps = [ ":client_library_test_base", @@ -1141,7 +1146,7 @@ xla_test( xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], - shard_count = 40, + shard_count = 31, tags = [ "optonly", ], @@ -1389,8 +1394,8 @@ xla_test( ) xla_test( - name = "fmax_test", - srcs = ["fmax_test.cc"], + name = "fmax_fmin_test", + srcs = ["fmax_fmin_test.cc"], deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -2155,3 +2160,26 @@ xla_test( "//tensorflow/compiler/xla:test", ], ) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = [ + "enable_for_xla_interpreter", + "noasan", # sometimes times out, http://b/78650012 + ], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 7379fbcc22745f46f2a29732c4bda46f352d07e7..acdd3c9da92efe8fae1336eaa861c01d5bb9b158 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc index 96f4aedf8b996b152b77628252841348e732756f..dfbf0478e62713635446d11557367cfac6ab0dce 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -37,6 +37,8 @@ struct BatchGroupedConvolution2DSpec { std::vector activation_dims; std::vector kernel_dims; std::vector output_dims; + std::vector activation_and_kernel_layout; + std::vector output_layout; }; class BatchGroupedConvolution2DTest @@ -47,8 +49,9 @@ class BatchGroupedConvolution2DTest static std::vector GetConv2DTestCases() { std::vector config_set; std::vector> config_options = { - {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, {16, 20, 20, 256}, - {256, 7, 5, 4}, {256, 6, 6, 4}, {256, 8, 8, 512}}; + {8, 5, 3, 2}, {4, 5, 5, 2}, {8, 7, 4, 128}, + {16, 20, 20, 256}, {256, 7, 5, 4}, {256, 6, 6, 4}, + {256, 8, 8, 512}, {64, 7, 7, 960}, {64, 14, 14, 576}}; for (auto option : config_options) { int64 feature = option[3]; @@ -68,8 +71,14 @@ static std::vector GetConv2DTestCases() { int64 output_space_size = 3 + activation_size - kernel_size; config.output_dims = {output_space_size, output_space_size, feature, 1}; + config.activation_and_kernel_layout = {0, 3, 1, 2}; + config.output_layout = {2, 3, 0, 1}; config_set.push_back(config); + BatchGroupedConvolution2DSpec different_layout_config = config; + different_layout_config.activation_and_kernel_layout = {3, 0, 1, 2}; + config_set.push_back(different_layout_config); + // Add configurations for window dilation cases. if (activation_size % 2 == 0 && activation_size == kernel_size) { BatchGroupedConvolution2DSpec config; @@ -79,11 +88,17 @@ static std::vector GetConv2DTestCases() { config.activation_dims = {batch, activation_size, activation_size, feature}; config.kernel_dims = {batch, kernel_size / 2, kernel_size / 2, feature}; + config.activation_and_kernel_layout = {0, 3, 1, 2}; + config.output_layout = {2, 3, 0, 1}; int64 output_space_size = 5; config.output_dims = {output_space_size, output_space_size, feature, 1}; config_set.push_back(config); + + BatchGroupedConvolution2DSpec different_layout_config = config; + different_layout_config.activation_and_kernel_layout = {3, 0, 1, 2}; + config_set.push_back(different_layout_config); } } @@ -97,8 +112,11 @@ string BatchGroupedConvolution2DTestDataToString( const string data_type = GetFloatDataType(::testing::get<1>(data.param)); string str = absl::StrCat( "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), - "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_output_dims_", - absl::StrJoin(spec.output_dims, "x"), data_type); + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), + "_activation_layout_", + absl::StrJoin(spec.activation_and_kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), data_type, "_output_layout_", + absl::StrJoin(spec.output_layout, "_")); // Test names are not allowed to contain the '-' character. absl::c_replace(str, '-', 'n'); @@ -110,23 +128,28 @@ string BuildHloTextBatchGroupedConvolution2D( const string data_type = GetFloatDataType(use_bfloat16); return absl::StrFormat( R"( - HloModule TensorFlowDepthwiseConv + HloModule TensorFlowDepthwiseConv, is_scheduled=true ENTRY main { - activation = %s[%s] parameter(0) - kernel = %s[%s] parameter(1) - ROOT conv = %s[%s] convolution(%s[%s] activation, %s[%s] kernel), + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), window={size=%dx%d pad=1_%dx1_%d rhs_dilate=%dx%d}, dim_labels=f01b_i01o->01fb, batch_group_count=%d } )", - data_type, absl::StrJoin(spec.activation_dims, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), data_type, - absl::StrJoin(spec.output_dims, ","), data_type, - absl::StrJoin(spec.activation_dims, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), spec.window, spec.window, - spec.window_dilation, spec.window_dilation, spec.window_dilation, - spec.window_dilation, spec.output_batch); + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.activation_and_kernel_layout, ","), spec.window, + spec.window, spec.window_dilation, spec.window_dilation, + spec.window_dilation, spec.window_dilation, spec.output_batch); } XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { @@ -135,13 +158,13 @@ XLA_TEST_P(BatchGroupedConvolution2DTest, DoIt) { const string hlo_text = BuildHloTextBatchGroupedConvolution2D(spec, use_bfloat16); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, - [](HloModule* module) -> Status { - BFloat16MixedPrecisionRemoval remover; - TF_RETURN_IF_ERROR(remover.Run(module).status()); - Despecializer despecializer; - return despecializer.Run(module).status(); - })); + EXPECT_TRUE(RunAndCompareNoHloPasses( + hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index cad43d1b5547d74701760fa623e50466fc15c263..4687ed61a7de91bc1bce0efeadf1965ad7d52d55 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -172,8 +172,10 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { const Shape& r2f32_dim0_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); - b.AddInstruction(HloInstruction::CreateCustomCall( + auto custom_call = b.AddInstruction(HloInstruction::CreateCustomCall( r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + b.AddInstruction( + custom_call->CloneWithNewOperands(r2f32_dim0_major, {custom_call})); module->AddEntryComputation(b.Build()); ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); @@ -182,7 +184,7 @@ XLA_TEST_F(CustomCallTest, LayoutConstrained) { Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); Literal result = ExecuteAndTransfer(std::move(module), {&argument}); - LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); + LiteralTestUtil::ExpectR2Equal({{3.f, 4.f}, {5.f, 6.f}}, result); } XLA_TEST_F(CustomCallTest, TupleOutput) { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index f740f4815810727890583405b2244fceaec0bd3f..6ee2178a227a12b7baa933f036a44db8ec630a4c 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1188,6 +1188,16 @@ std::vector GetEinsumTestCases() { p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"}, p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, p{v{6}, v{6, 7}, "b,bc->c"}, + p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"}, + p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"}, + p{v{77}, v{77}, "a,a->a"}, + p{v{77}, v{77, 55}, "a,ab->ba"}, + p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, + p{v{55}, v{}, "a,->a"}, + p{v{11, 111}, v{11}, "ab,a->ab"}, + p{v{16, 34}, v{16, 34}, "ab,ab->ab"}, + p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"}, + p{v{5, 19}, v{}, "ab,->ab"}, }; return test_cases; } @@ -1257,5 +1267,82 @@ ENTRY %test { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) { + // Tests for a caching bug in the XLA CPU backend. + absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(2) + rhs_1 = f32[1,40] parameter(1) + + dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + ROOT result = f32[20,1] divide(dot_0, dot_1) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) { + // Tests for a caching bug in the XLA CPU backend. + absl::string_view hlo_string = + R"( +HloModule CpuTiledDotEmitterCachingBug + +ENTRY main { + lhs_0 = f32[20,40] parameter(0) + rhs_0 = f32[40,1] parameter(1) + lhs_1 = f32[1,40] parameter(2) + rhs_1 = f32[20,40] parameter(3) + + dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + + dot_0_reshaped = f32[20] reshape(dot_0) + dot_1_reshaped = f32[20] reshape(dot_1) + + ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuIntegerDotCodegen)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s32[1,2,2] parameter(0) + arg1 = s32[1,2,1] parameter(1) + ROOT dot = s32[1,2,1] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(GpuTransposeOutput)) { + absl::string_view hlo_string = + R"( +HloModule TransposeOutput + +ENTRY TransposeOutput { + p0 = f32[32,32] parameter(0) + p1 = f32[32,64] parameter(1) + dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0} + ROOT tr = f32[64,32] transpose(dot), dimensions={1,0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/fmax_fmin_test.cc b/tensorflow/compiler/xla/tests/fmax_fmin_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7423ac0bcdb0bc305ee384fb98bd17413404ecef --- /dev/null +++ b/tensorflow/compiler/xla/tests/fmax_fmin_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class FmaxSimpleTest : public ClientLibraryTestBase {}; + +TEST_F(FmaxSimpleTest, FmaxTenValues) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); + + std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(FmaxSimpleTest, FmaxEdgeCases) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + XlaOp param0, param1; + std::unique_ptr param0_data = CreateR1Parameter( + {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN, + INFINITY, -INFINITY, NAN}, + /*parameter_number=*/0, /*name=*/"param0", + /*builder=*/&builder, /*data_handle=*/¶m0); + std::unique_ptr param1_data = CreateR1Parameter( + {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0}, + /*parameter_number=*/1, /*name=*/"param1", + /*builder=*/&builder, /*data_handle=*/¶m1); + + Max(param0, param1); + std::vector expected = {INFINITY, INFINITY, NAN, NAN, INFINITY, + -5, NAN, INFINITY, 8, NAN}; + ComputeAndCompareR1(&builder, expected, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001)); +} + +TEST_F(FmaxSimpleTest, FminEdgeCases) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + XlaOp param0, param1; + std::unique_ptr param0_data = CreateR1Parameter( + {INFINITY, INFINITY, INFINITY, -INFINITY, INFINITY, -INFINITY, NAN, + INFINITY, -INFINITY, NAN}, + /*parameter_number=*/0, /*name=*/"param0", + /*builder=*/&builder, /*data_handle=*/¶m0); + std::unique_ptr param1_data = CreateR1Parameter( + {INFINITY, -INFINITY, NAN, NAN, -4.0, -5.0, -6.0, 7.0, 8.0, 9.0}, + /*parameter_number=*/1, /*name=*/"param1", + /*builder=*/&builder, /*data_handle=*/¶m1); + + Min(param0, param1); + std::vector expected = {INFINITY, -INFINITY, NAN, NAN, -4, + -INFINITY, NAN, 7, -INFINITY, NAN}; + ComputeAndCompareR1(&builder, expected, + {param0_data.get(), param1_data.get()}, + ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc deleted file mode 100644 index c5bbbe778df15d63a2586bd6291a7a33fc82aa52..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class FmaxSimpleTest : public ClientLibraryTestBase {}; - -TEST_F(FmaxSimpleTest, FmaxTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1( - &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = ConstantR1( - &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - Max(x, y); - - std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, - 5.0, 6.0, 7.0, 8.0, 9.0}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 66f72ba8d20b8ef1f436da4425b2bb6518ee9a94..0151981ef16aabe9e363bc4d7f9ba96d4a1f170f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -205,6 +205,17 @@ Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } +StatusOr> HloTestBase::ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + for (auto argument : arguments) { + options.arguments.push_back(argument); + } + return test_runner_.ExecuteReplicated(std::move(module), options); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { @@ -313,7 +324,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } -::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { +::testing::AssertionResult HloTestBase::Run(string_view hlo_string, + bool run_hlo_passes, + ExecutionProfile* profile, + string backend_config) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); if (!module_or_status.ok()) { @@ -321,19 +335,108 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( << "Error while parsing HLO text format: " << module_or_status.status().ToString(); } + + std::unique_ptr module = std::move(module_or_status.ValueOrDie()); const auto& fake_arguments = - MakeFakeArguments(module_or_status.ValueOrDie().get()) - .ConsumeValueOrDie(); + MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const Literal& literal) { return const_cast(&literal); }); - return test_runner_ - .Execute(std::move(module_or_status.ValueOrDie()), - fake_argument_ptrs, /*run_hlo_passes=*/true) - .ok() + + if (profile != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + module->set_config(config); + } + + if (!backend_config.empty()) { + // Set backend configuration if it is given. + HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + // return ::testing::AssertionSuccess(); + auto output = test_runner_.Execute(std::move(module), fake_argument_ptrs, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); + + return output.ok() ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure(); + : ::testing::AssertionFailure() << output.status().error_message(); +} + +::testing::AssertionResult HloTestBase::RunMultipleTimes( + string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, string backend_config) { + int n = profiles->size(); + std::vector> fake_argument_ptrs(n); + std::vector> fake_arguments(n); + std::vector> executables(n); + + for (int i = 0; i < n; ++i) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + std::unique_ptr module = + std::move(module_or_status.ValueOrDie()); + + fake_arguments[i] = MakeFakeArguments(module.get()).ConsumeValueOrDie(); + absl::c_transform( + fake_arguments[i], std::back_inserter(fake_argument_ptrs[i]), + [](const Literal& literal) { return const_cast(&literal); }); + + if (profiles != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = module->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + module->set_config(config); + } + + if (!backend_config.empty()) { + // Set backend configuration if it is given. + HloInstruction* instruction = + module->entry_computation()->root_instruction(); + instruction->set_raw_backend_config_string(backend_config); + } + + auto executable = + test_runner_.CreateExecutable(std::move(module), run_hlo_passes); + if (!executable.ok()) { + return ::testing::AssertionFailure() + << executable.status().error_message(); + } + executables[i] = std::move(executable.ValueOrDie()); + } + + for (int i = 0; i < n; ++i) { + auto output = + test_runner_.Execute(std::move(executables[i]), fake_argument_ptrs[i], + /*profile=*/&((*profiles)[i])); + if (!output.ok()) { + return ::testing::AssertionFailure() << output.status().error_message(); + } + } + + return ::testing::AssertionSuccess(); } ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 69a4f96288c7285010e9adbdc33f1b394f58d8d2..3c2bcbb5df5ce94dd37f63d0c0e609f3ad2b60aa 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -173,6 +173,11 @@ class HloTestBase : public ::testing::Test { Literal ExecuteAndTransfer(std::unique_ptr module, absl::Span arguments); + // Executes the given module on multiple replicas. + StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64 num_replicas); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. @@ -221,8 +226,14 @@ class HloTestBase : public ::testing::Test { const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; - ::testing::AssertionResult Run(const absl::string_view hlo_string) - TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const absl::string_view hlo_string, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr, + string backend_config = "") TF_MUST_USE_RESULT; + ::testing::AssertionResult RunMultipleTimes( + const absl::string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, + string backend_config = "") TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) diff --git a/tensorflow/compiler/xla/tests/plugin.bzl b/tensorflow/compiler/xla/tests/plugin.bzl index 8a5d91363b619c6b214a96ad96e92742e3052541..107869fe59d43d0a9a3e2b14af2c09e4906d9f15 100644 --- a/tensorflow/compiler/xla/tests/plugin.bzl +++ b/tensorflow/compiler/xla/tests/plugin.bzl @@ -33,4 +33,3 @@ # } plugins = {} - diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index f80d29b9de440b11c36e8c9bc65d4a93353a6267..e2cf4c0be289b52d5cc581ea07752ed6e98da76f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 95c89b0ba6f29c453abab88e29bca13ee006455a..67d2258928f75c078588c9425359f9468f4463ed 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -238,6 +238,79 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, return std::move(literal); } +template +void PopulateWithRandomIntegralDataWithBounds(Literal* literal, + std::minstd_rand0* engine, + IntT min, IntT max) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::uniform_int_distribution generator(min, max); + for (IntT& value : literal->data()) { + value = generator(*engine); + } +} + +// Same as MakeFakeLiteralInternal but generates random numbers in the given +// range [min, max]. Currently this works only for INT types. +StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, + std::minstd_rand0* engine, + int64 min, int64 max) { + if (shape.IsTuple()) { + std::vector elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN( + Literal element, + MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max)); + elements.push_back(std::move(element)); + } + return LiteralUtil::MakeTupleOwned(std::move(elements)); + } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + Literal literal(shape); + switch (shape.element_type()) { + case S8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U8: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U16: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U32: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case S64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + case U64: + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, static_cast(min), static_cast(max)); + break; + default: + return Unimplemented( + "Unsupported type for fake random literal generation with bounds: %s", + ShapeUtil::HumanString(shape)); + } + return std::move(literal); +} + enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. @@ -297,6 +370,10 @@ std::vector FindConstrainedUses( if ((opcode == HloOpcode::kDynamicSlice && op_num >= 1) || (opcode == HloOpcode::kDynamicUpdateSlice && op_num >= 2)) { constrained_uses.push_back(instruction); + } else if ((opcode == HloOpcode::kGather || + opcode == HloOpcode::kScatter) && + op_num == 1) { + constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kFusion) { const HloInstruction* const to_analyze = instruction->fused_parameter(op_num); @@ -356,6 +433,22 @@ StatusOr CreateLiteralForConstrainedUses( } break; } + case HloOpcode::kGather: + case HloOpcode::kScatter: { + const Shape& operand_shape = use->operand(0)->shape(); + if (use->operand(1) == ¶m) { + auto index_map = + use->opcode() == HloOpcode::kGather + ? use->gather_dimension_numbers().start_index_map() + : use->scatter_dimension_numbers() + .scatter_dims_to_operand_dims(); + for (const auto dim_in_operand : index_map) { + index_bound = + std::min(index_bound, operand_shape.dimensions(dim_in_operand)); + } + } + break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = true; @@ -385,8 +478,8 @@ StatusOr CreateLiteralForConstrainedUses( return Unimplemented("Conflicting operand generation constraints."); } if (index_bound != INT64_MAX) { - return MakeRandomIndex(index_bound, engine) - .Reshape(param.shape().dimensions()); + return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1, + index_bound); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 591d6c19228a313f530cdae18f4be37e7b517601..f68ee04565f3898bd3db455e3e102bc2edb6255a 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -92,12 +92,13 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - EXPECT_EQ(args[0].Get({}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(args[1].Get({}), 0); - EXPECT_LE(args[0].Get({}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(args[2].Get({}), 0); + EXPECT_GE(args[2].Get({}), -1); EXPECT_LE(args[2].Get({}), 3); } @@ -122,12 +123,13 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 7); - EXPECT_EQ(args[0].Get({}), 0); + EXPECT_GE(args[0].Get({}), -1); + EXPECT_LE(args[0].Get({}), 1); - EXPECT_GE(args[1].Get({}), 0); - EXPECT_LE(args[0].Get({}), 2); + EXPECT_GE(args[1].Get({}), -1); + EXPECT_LE(args[1].Get({}), 2); - EXPECT_GE(args[2].Get({}), 0); + EXPECT_GE(args[2].Get({}), -1); EXPECT_LE(args[2].Get({}), 3); } @@ -136,10 +138,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { auto module = ParseHloString(R"( HloModule sort.148.1589 +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { %parameter.0 = f32[1048576]{0} parameter(0) %parameter.1 = s32[1048576]{0} parameter(1) - ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} + ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare } )") .ValueOrDie(); @@ -159,10 +169,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { auto module = ParseHloString(R"( HloModule sort.148.1589 +compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { %parameter.0 = s32[1048576]{0} parameter(0) %parameter.1 = s32[1048576]{0} parameter(1) - ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} + ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare } )") .ValueOrDie(); @@ -182,10 +200,18 @@ XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { auto module = ParseHloString(R"( HloModule sort, is_scheduled=true +compare { + p.0.lhs = bf16[] parameter(0) + p.0.rhs = bf16[] parameter(1) + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs) +} + ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { %parameter.0 = bf16[2,1452]{1,0} parameter(0) %parameter.1 = s32[2,1452]{1,0} parameter(1) - ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1} + ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1}, to_apply=compare } )") .ValueOrDie(); @@ -228,5 +254,77 @@ ENTRY %module (parameter.0: s32[], parameter.1: f32[20,20]) -> f32[] { << ShapeUtil::HumanString(args[1].shape()); } +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForGather) { + auto module = ParseHloString(R"( + HloModule Test + +ENTRY %module(paramater.0: f32[200,100,300], parameter.1: s32[10,2]) -> + f32[10,300] { + %parameter.0 = f32[200,100,300] parameter(0) + %parameter.1 = s32[10,2] parameter(1) + ROOT gather = f32[10,300] gather(f32[200,100,300] %parameter.0, + s32[10,2] %parameter.1), + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, + index_vector_dim=1, + slice_sizes={1,1,300} +} +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + +XLA_TEST_F(TestUtilsTest, MakeFakeArgumentsForScatter) { + auto module = ParseHloString(R"( + HloModule Test + +scatter_update (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + ROOT rhs = f32[] parameter(1) +} + +ENTRY main { + operand = f32[200,100,300] parameter(0) + indices = s32[10,2] parameter(1) + updates = f32[10,300] parameter(2) + ROOT scatter = f32[200,100,300] scatter(operand, indices, updates), + to_apply=scatter_update, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 + } +)") + .ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + + const Shape& indices_shape = args[1].shape(); + EXPECT_TRUE( + ShapeUtil::Equal(indices_shape, ShapeUtil::MakeShape(S32, {10, 2}))) + << ShapeUtil::HumanString(indices_shape); + auto indices = args[1].data(); + for (const auto index : indices) { + EXPECT_GE(index, -1); + EXPECT_LE(index, 100); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/tests/triangular_solve_test.cc similarity index 77% rename from tensorflow/compiler/xla/client/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/tests/triangular_solve_test.cc index 284a2e9d183a6a7923fb59ac134ce3b3a3a96e35..24ab12136ff396bd9ac37bb058311b0d2d6f2515 100644 --- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/tests/triangular_solve_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" - #include #include #include @@ -54,6 +52,20 @@ Array2D AValsUpper() { {kNan, kNan, kNan, 11}}; } +Array2D AValsLowerUnitDiagonal() { + return {{kNan, kNan, kNan, kNan}, + {3, kNan, kNan, kNan}, + {4, 7, kNan, kNan}, + {5, 8, 10, kNan}}; +} + +Array2D AValsUpperUnitDiagonal() { + return {{kNan, 3, 4, 5}, + {kNan, kNan, 7, 8}, + {kNan, kNan, kNan, 10}, + {kNan, kNan, kNan, kNan}}; +} + Array2D BValsRight() { return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; } @@ -96,8 +108,8 @@ XLA_TEST_F(TriangularSolveTest, EmptyArrays) { CreateR2Parameter(Array2D(0, 10), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); ComputeAndCompareR2(&builder, Array2D(0, 10), {a_data.get(), b_data.get()}); @@ -111,8 +123,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -132,8 +144,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -153,8 +165,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -174,8 +186,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -195,8 +207,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -217,8 +229,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {0.5, 1.0, 1.5}, @@ -231,6 +243,25 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsLowerUnitDiagonal(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*unit_diagonal=*/true, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected( + {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { XlaBuilder builder(TestName()); @@ -239,8 +270,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/3); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {0.5, 1.0, 1.5}, @@ -261,8 +292,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {0.5, 1.0, 1.5}, @@ -283,8 +314,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -297,6 +328,27 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { ErrorSpec(1e-2, 1e-2)); } +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) { + XlaBuilder builder(TestName()); + + XlaOp a, b; + auto a_data = + CreateR2Parameter(AValsUpperUnitDiagonal(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/true, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE); + + Array2D expected({{-1402., -1538., -1674.}, + {575., 631., 687.}, + {-93., -102., -111.}, + {10., 11., 12.}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + ErrorSpec(1e-2, 1e-2)); +} + XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { XlaBuilder builder(TestName()); @@ -307,8 +359,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::ADJOINT); Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -333,8 +385,8 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); TriangularSolve(a, b, /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); Array2D expected({ {0.5, 1., 1.5}, @@ -368,11 +420,12 @@ XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) { XlaOp a, b; auto a_data = CreateR3Parameter(avals, 0, "a", &builder, &a); auto b_data = CreateR3Parameter(bvals, 1, "b", &builder, &b); - BatchDot(ConstantR3FromArray3D(&builder, avals), - TriangularSolve(a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2)); + BatchDot( + ConstantR3FromArray3D(&builder, avals), + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE)); ComputeAndCompareR3(&builder, bvals, {a_data.get(), b_data.get()}, ErrorSpec(1e-2, 1e-2)); @@ -382,7 +435,7 @@ struct TriangularSolveTestSpec { int m, n; // A is mxm, B is mxn bool left_side; bool lower; - bool transpose_a; + TriangularSolveOptions::Transpose transpose_a; }; class TriangularSolveParametricTest @@ -408,11 +461,11 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) { XlaOp a, b; auto a_data = CreateR2Parameter(avals, 0, "a", &builder, &a); auto b_data = CreateR2Parameter(bvals, 1, "b", &builder, &b); - auto x = TriangularSolve(a, b, spec.left_side, spec.lower, spec.transpose_a, - /*conjugate_a=*/false, - /*block_size=*/3); + auto x = TriangularSolve(a, b, spec.left_side, spec.lower, + /*unit_diagonal=*/false, spec.transpose_a); auto a_tri = Triangle(a, spec.lower); - a_tri = MaybeTransposeInMinorDims(a_tri, spec.transpose_a); + a_tri = MaybeTransposeInMinorDims( + a_tri, spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE); if (spec.left_side) { BatchDot(a_tri, x); } else { @@ -429,7 +482,9 @@ std::vector TriangularSolveTests() { for (int n : {5, 10}) { for (bool left_side : {false, true}) { for (bool lower : {false, true}) { - for (bool transpose_a : {false, true}) { + for (TriangularSolveOptions::Transpose transpose_a : + {TriangularSolveOptions::NO_TRANSPOSE, + TriangularSolveOptions::TRANSPOSE}) { specs.push_back({m, n, left_side, lower, transpose_a}); } } diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 4fbd7f2fb174ac899c1e3b23801986cb52db96a2..c51f30f3b5db95962a719ec226dd03f41142a782 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -64,7 +64,9 @@ class UnaryOpTest : public ClientLibraryTestBase { &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); Sign(arg); - ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); + ComputeAndCompareR1( + &builder, + {-1, 1, static_cast(+0.0), static_cast(-0.0), -1, 1, -1}, {}); } template diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index c7337e8caae8f2ee25f4b25dc22439e08d2ecc25..7b7b8f5d02dc99607b30f898e18c5b448d421e07 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -40,8 +40,6 @@ limitations under the License. namespace xla { namespace { -namespace gtl = ::tensorflow::gtl; - class HloProfileTest : public ClientLibraryTestBase {}; struct ParsedProfileOutputLine { diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 7289ae7df65e56652eeeb67e536e4c721d97d999..fc7949d889dc8ed9fac425982cc555a6c42a7f1d 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 52fee4770ab940741723514d742e998b25765f24..ebd4bb1e42c9d1dc1f72a75514e916a2d900c30e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -177,26 +177,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "dumped_computation_to_tf_graphdef", - srcs = ["dumped_computation_to_tf_graphdef.cc"], - deps = [ - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_binary( name = "hlo_proto_to_json", srcs = ["hlo_proto_to_json.cc"], diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 4375e7c138c9e8d193feaa7a39d63946c4ea3086..df2d3d18b9ff86c0dd2047c2415527aeb1c1f154 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 723569862c7550387e95003e3a673743464b67b8..35bb82ca22f46d2cdeaac3b9a87b253efe9a07d9 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc deleted file mode 100644 index f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* -// -// Dumps a tensorflow GraphDef in text format for a snapshot computation. The -// dumped graph is an HLO computation with HLO instructions as nodes and can be -// visualized on Tensorboard. Upload the dumped files on Tensorboard. -// -// some_binary_snapshot_proto is obtained by serializing the SessionModule from -// ServiceInterface::SnapshotComputation to disk. - -#include -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" - -using tensorflow::Env; - -namespace xla { -namespace tools { - -void RealMain(absl::Span args) { - Client* client = ClientLibrary::LocalClientOrDie(); - for (char* arg : args) { - HloSnapshot module; - TF_CHECK_OK( - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - XlaComputation computation = - client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = GetDebugOptionsFromFlags(); - debug_options.set_xla_generate_hlo_graph(".*"); - debug_options.set_xla_hlo_dump_as_graphdef(true); - ComputationStats stats = - client->GetComputationStats(computation, debug_options) - .ConsumeValueOrDie(); - fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); - } -} - -} // namespace tools -} // namespace xla - -int main(int argc, char** argv) { - std::vector flag_list; - xla::AppendDebugOptionsFlags(&flag_list); - xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_result) { - LOG(ERROR) << "\n" << usage; - return 2; - } - - tensorflow::port::InitMain(argv[0], &argc, &argv); - - absl::Span args(argv, argc); - args.remove_prefix(1); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); - return 0; -} diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc index ac865707f8697e0b94173a2a33e7be52a9564867..0c7c078b9b9d30427cb01b8930bd012046d852d3 100644 --- a/tensorflow/compiler/xla/tools/interactive_graphviz.cc +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -139,9 +139,10 @@ HloComputation* FindComputation(const HloModule& module, // Print a help message describing the various available commands. void DoHelpCommand() { std::cout << R"(Commands: - [] - Renders a neighborhood of nodes around . If - is not provided, the default value is )" + [] [/ +] + Renders a neighborhood of nodes around , without going + beyond the optional boundary instructions. If is not provided, + the default value is )" << kDefaultWidth << R"(. allpaths [] Renders a subset of all paths from one instruction to the other. Either @@ -457,12 +458,6 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module, // Plot a given instruction neighborhood or computation with graphviz. void DoPlotCommand(const Options& opts, const HloModule& module, const std::vector& tokens) { - if (tokens.size() > 2) { - std::cerr << R"(Illegal input. Enter e.g. "%fusion.1 42" or "%fusion.1".)" - << std::endl; - return; - } - string node_name = tokens[0]; // Find the node with the given name. @@ -475,16 +470,43 @@ void DoPlotCommand(const Options& opts, const HloModule& module, } uint64 graph_width = kDefaultWidth; - if (tokens.size() == 2) { + absl::flat_hash_set boundary; + if (tokens.size() >= 2) { if (comp) { std::cerr << "Can only use graph-size parameter with instructions, but " << node_name << " is a computation." << std::endl; return; } - if (!absl::SimpleAtoi(tokens[1], &graph_width)) { - std::cerr << "Can't parse '" << tokens[1] << "' as an integer." - << std::endl; - return; + + int bound_index = 1; + // Get the if present. + if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) { + bound_index++; + } else { + // not found, need to reset graph_width. + graph_width = kDefaultWidth; + } + // Get the '/'. + if (bound_index < tokens.size()) { + // This token must be a '/'. + if (tokens[bound_index] != "/") { + std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'." + << std::endl; + return; + } + bound_index++; + } + // Get the boundary nodes. + while (bound_index < tokens.size()) { + string bnode_name = tokens[bound_index]; + const HloInstruction* binstr = FindInstruction(module, bnode_name); + if (!binstr) { + std::cerr << "Couldn't find HloInstruction named " << bnode_name << "." + << std::endl; + return; + } + boundary.insert(binstr); + bound_index++; } } @@ -496,7 +518,9 @@ void DoPlotCommand(const Options& opts, const HloModule& module, /*show_backend_config=*/show_backend_config)); } else { DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround( - *instr, graph_width, /*show_backend_config=*/show_backend_config)); + *instr, graph_width, + /*show_backend_config=*/show_backend_config, + /*boundary=*/boundary)); } } @@ -515,7 +539,7 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) { << std::endl; continue; } - std::vector tokens = absl::StrSplit(line, ' '); + std::vector tokens = absl::StrSplit(line, ' ', absl::SkipEmpty()); if (tokens[0] == "quit" || tokens[0] == "exit") { break; } else if (tokens[0] == "help") { diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 21217c23f6561a509cb3e30bf3dc841f8dc5db87..d66561315b4ad7a5e3f1f7b1bc1e557b71da6705 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -102,8 +102,9 @@ StatusOr> CompileExecutable( argument_layouts.push_back(Shape(param)); argument_layout_ptrs.push_back(&argument_layouts.back()); } - return client->Compile(computation, argument_layout_ptrs, - ExecutableBuildOptions()); + ExecutableBuildOptions exec_build_options; + *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags(); + return client->Compile(computation, argument_layout_ptrs, exec_build_options); } absl::optional GetXfeedShape(bool is_infeed, @@ -328,7 +329,10 @@ StatusOr ParseInputFile(const string& filename, fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); string contents; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); - StatusOr> module = ParseHloString(contents); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsFromFlags()); + StatusOr> module = + ParseHloString(contents, config); if (module.ok()) { *snapshot.mutable_hlo()->mutable_hlo_module() = module.ValueOrDie()->ToProto(); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index cdf306dfd1027cf6022c5d8ae844b4308f580e8d..b80d0db8d812380d8144713109d1c05168713c77 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 34b73b5206fa20d6dff7567afd78fd89897c8c33..bb8bbf57c4252b16836553334901a3c896a17f39 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -80,13 +81,9 @@ bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } - std::vector output(permutation.size(), -1); - for (auto index : permutation) { - CHECK_GE(index, 0); - CHECK_LT(index, rank); - output[index] = 0; - } - return !absl::c_linear_search(output, -1); + absl::InlinedVector trivial_permutation(rank); + absl::c_iota(trivial_permutation, 0); + return absl::c_is_permutation(permutation, trivial_permutation); } std::vector InversePermutation( diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 60adea5a4a242e5843b41927ba77c197e8fac444..cda2d7c7c6b2403868f6d01a485753fa29a8d95f 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,7 +17,9 @@ def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = cc_proto_library( name = name, srcs = srcs, - deps = deps, + # Append well-known proto dep. As far as I know this is the only way + # for xla_proto_library to access google.protobuf.{Any,Duration,...}. + deps = deps + ["@protobuf_archive//:cc_wkt_protos"], cc_libs = if_static( ["@protobuf_archive//:protobuf"], otherwise = ["@protobuf_archive//:protobuf_headers"], @@ -28,6 +30,11 @@ def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = **kwargs ) +def xla_py_proto_library(**kwargs): + # Note: we don't currently define a proto library target for Python in OSS. + _ignore = kwargs + pass + def xla_py_grpc_library(**kwargs): # Note: we don't currently define any special targets for Python GRPC in OSS. _ignore = kwargs diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 92834dbb02cdcd6383ceec3ffd079834b163ee6a..925fcbf88c1e8dd81ab1339d292e05eae52e0d13 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -15,11 +15,11 @@ limitations under the License. syntax = "proto3"; -import "tensorflow/compiler/xla/xla_data.proto"; -import "tensorflow/compiler/xla/service/hlo.proto"; - package xla; +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + // Options for the HLO insert-reduce-precision-operations pass. message HloReducePrecisionOptions { // Where and when the reduce-precision operations will be added. @@ -72,8 +72,7 @@ message DebugOptions { // Path to dump HLO graphs to. string xla_hlo_graph_path = 4; - // Dump HLO graphs as TensorFlow GraphDefs. - bool xla_hlo_dump_as_graphdef = 5; + reserved 5; // Was xla_hlo_dump_as_graphdef // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*" to // dump *all* HLO modules. @@ -171,9 +170,7 @@ message DebugOptions { // HLO graph. bool xla_hlo_graph_sharding_color = 92; - // Prefix the name scopes of the TF graph exports with "devX" device - // assignments, if available. - bool xla_hlo_tfgraph_device_scopes = 93; + reserved 93; // Was xla_hlo_tfgraph_device_scopes // If true, the GPU backend is free to use cudnn for HLO batch normalization // ops. @@ -234,7 +231,23 @@ message DebugOptions { // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. bool xla_allow_scalar_index_dynamic_ops = 107; - // Next id: 108 + enum StepMarkerLocation { + // Generate step mark at each iteration of top level while loop, which + // is assumed to be a training loop. This is the default. + STEP_MARK_AT_ENTRY = 0; + // Generate step mark at program entry. This handles the case where each + // step are done by one or multiple programs execution. Only the first + // program will be tagged for generating step mark at program entry. + STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1; + // No step mark. + STEP_MARK_NONE = 2; + } + // Option to emit a target-specific marker to indicate the start of a training + // step. The location of the marker (if any) is determined by the option + // value. + StepMarkerLocation xla_step_marker_location = 108; + + // Next id: 109 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -306,8 +319,7 @@ message TransferToInfeedRequest { DeviceHandle device_handle = 3; } -message TransferToInfeedResponse { -} +message TransferToInfeedResponse {} message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this @@ -326,8 +338,7 @@ message ResetDeviceRequest { DeviceHandle device_handle = 1; } -message ResetDeviceResponse { -} +message ResetDeviceResponse {} message ComputationGraphStatsRequest { HloModuleProto computation = 1; @@ -350,8 +361,7 @@ message UnregisterRequest { repeated GlobalDataHandle data = 1; } -message UnregisterResponse { -} +message UnregisterResponse {} message CompileRequest { // The graph to be compiled. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index a64e2f5df5cacca05e83f31c941c57abd5ccf4de..226299a7186ef0acb41f6d01fdeffeee06f13d4d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -545,6 +545,26 @@ enum RandomDistribution { // Next: 4 } +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + }; + Transpose transpose_a = 4; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -604,3 +624,15 @@ message PrecisionConfig { // Next: 2 } + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 2dae746d034a1bf52e84de74dfb0c6e23aaed4d1..b2718c5c283358d98da175a8d3b21bb1f2b01c75 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -11,9 +11,15 @@ package( load( "//tensorflow:tensorflow.bzl", + "tf_custom_op_py_library", "tf_gen_op_libs", + "tf_gen_op_wrapper_py", ) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( name = "xrt_proto", @@ -27,6 +33,12 @@ xla_proto_library( ], ) +tf_proto_library_py( + name = "xrt_proto", # bzl adds a _py suffix + srcs = ["xrt.proto"], + visibility = ["//visibility:public"], +) + cc_library( name = "xrt_utils", srcs = [ @@ -78,6 +90,25 @@ tf_gen_op_libs( ], ) +tf_gen_op_wrapper_py( + name = "xrt_ops_wrapper_py", + out = "xrt_ops.py", + deps = [ + ":xrt_compile_ops_op_lib", + ":xrt_execute_op_op_lib", + ":xrt_state_ops_op_lib", + ], +) + +tf_custom_op_py_library( + name = "xrt_ops", + kernels = ["//tensorflow/compiler/xrt/kernels:xrt_ops"], + visibility = ["//visibility:public"], + deps = [ + ":xrt_ops_wrapper_py", + ], +) + cc_library( name = "xrt_server", visibility = ["//visibility:public"], diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index dc02fd272fd8700c7f8fa64adf7ab57c88bab706..1e325191bba828e3d5e4599f87dcf4f4d0674945 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -51,7 +51,10 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xrt:xrt_compile_ops_op_lib", + "//tensorflow/compiler/xrt:xrt_execute_op_op_lib", "//tensorflow/compiler/xrt:xrt_proto", + "//tensorflow/compiler/xrt:xrt_state_ops_op_lib", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 2ee1a6cd1aebcdbd65892b33e5044489070ab5c4..b791519c09758a4f4124c95add5351a9433ecb8f 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -68,9 +68,11 @@ class XRTCompileOp : public OpKernel { Status CompilationCacheKey(const xrt::XLAComputation& computation, string* key) { - string serialized; - TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized)); - uint64 fingerprint = Fingerprint64(serialized); + const size_t size = computation.ByteSizeLong(); + auto serialized = absl::make_unique(size); + TF_RET_CHECK( + SerializeToBufferDeterministic(computation, serialized.get(), size)); + uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size)); *key = absl::StrCat(fingerprint); return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 116c193cab65410a5a7c3058f98cc2be2cbe9e67..42ef88168af4b6f391ffc2e69ab4c4000d1cbee1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index 6a7f10652533920ba3fa48fba1d5161f7c4d4530..343f43b7159b55bad184eed2cada55c76085ffa0 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -122,6 +122,17 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_GPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); +REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") + .Device(DEVICE_XLA_CPU) + .HostMemory("handles") + .HostMemory("tensors"), + XRTReadToTensorOp); + REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") .Device(DEVICE_XLA_GPU) .HostMemory("handle"), diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index e2c223b3dbb2311d0f42e1a36e316fd9d5f66040..6af73ecc85351a9b38ba526db076e9176d1cb2f1 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -215,27 +217,29 @@ class XRTAllocateFromTensorOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple)); + std::vector minor_to_major; if (ctx->HasAttr("layouts")) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major)); } OP_REQUIRES( ctx, tf_shapes_.size() == dtypes_.size(), errors::InvalidArgument("shapes and dtypes must be the same length")); std::vector xla_shapes; + xla_shapes.reserve(tf_shapes_.size()); for (int i = 0; i < tf_shapes_.size(); i++) { xla::Shape xla_shape; OP_REQUIRES_OK( ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape)); - xla_shapes.push_back(xla_shape); + xla_shapes.push_back(std::move(xla_shape)); } if (xla_shapes.size() > 1 || make_tuple) { shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes); } else { shape_.Swap(&xla_shapes.front()); } - if (!minor_to_major_.empty()) { + if (!minor_to_major.empty()) { xla::Shape shape_with_layouts; - OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major_, + OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major, /*layout_func=*/nullptr, &shape_with_layouts)); shape_.Swap(&shape_with_layouts); @@ -304,7 +308,6 @@ class XRTAllocateFromTensorOp : public OpKernel { private: std::vector tf_shapes_; DataTypeVector dtypes_; - std::vector minor_to_major_; xla::Shape shape_; }; @@ -487,7 +490,7 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - xla::Literal literal; + xla::Literal literal(allocation->on_host_shape()); OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); @@ -499,6 +502,96 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that reads a device-resident tuple to host memory and returns it as a +// literal. +template +class XRTReadToTensorOp : public OpKernel { + public: + explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + } + ~XRTReadToTensorOp() override = default; + XRTReadToTensorOp(const XRTReadToTensorOp&) = delete; + XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTReadToTensorOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not + // just scalars.) + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + + if (discard_) { + VLOG(2) << "Releasing handle " << allocation_handle; + OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager( + rm, allocation_handle)); + } + + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + class DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + + xla::Shape shape = allocation->on_host_shape(); + int output = 0; + Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus( + &shape, + [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status { + if (subshape->IsTuple()) return Status::OK(); + + xla::PrimitiveType xla_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType( + ctx->expected_output_dtype(output), &xla_type)); + if (xla_type != subshape->element_type()) { + return errors::InvalidArgument( + "Type mismatch between buffer type (", subshape->ToString(), + ") and tensor type (", + DataTypeString(ctx->expected_output_dtype(output)), + ") for output tensor ", output); + } + + TensorShape output_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape)); + + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + ctx->allocate_output(output, output_shape, &output_tensor)); + + XRTTupleAllocation* sub; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + allocation, index, &sub, /*alias_parent_allocation=*/true)); + core::ScopedUnref sub_unref(sub); + + xla::MutableBorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor, + &literal)); + TF_RETURN_IF_ERROR(sub->ToLiteral( + device_ref.backend(), device_ref.device_ordinal(), &literal)); + + ++output; + return Status::OK(); + }); + OP_REQUIRES_OK(ctx, status); + } + bool discard_; + DataTypeVector dtypes_; +}; + // Op that writes a new literal value into device-resident memory. template class XRTWriteLiteralOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 2e743fec4963a52ee1abf64525f26e3d89479670..8832270fb2730d1ba64fa069b38f4a04b61773ef 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -151,6 +151,27 @@ releases the handle. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTReadToTensor") + .Input("handles: int64") + .Attr("release_handles: bool = False") + .Attr("dtypes: list(type)") + .Output("tensors: dtypes") + .SetShapeFn(tensorflow::shape_inference::UnknownShape) + .Doc( + R"( +Copies allocated values from device memory and returns them as zero or more +Tensors. If a handle refers to a non-tuple buffer, a single tensor is returned. +In general, the tensors returned for a handle correspond to an in-order traversal +of a the tuple-tree value referenced by the handle. + +'handles' contains ids returned from Ops that produced on-device allocations. +At present, only a single (scalar) handle is supported. +'dtypes' are the expected types for each `Tensor` to be returned. If the +expected and actual tensor types do not match, an error is returned. +'release_handles': if True, `handles` are released. +'tensors' are the output Tensors. +)"); + REGISTER_OP("XRTReleaseAllocationHandle") .Input("handle: int64") .SetShapeFn(tensorflow::shape_inference::NoOutputs) diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 1e2a9584f88b73d7c92a929e93af60376a59170b..1b3bcbea4c1228944a6604fc923228024e74d700 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/random/random.h" @@ -221,7 +220,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal) { + xla::MutableLiteralBase* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); @@ -235,9 +234,8 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, " has been released"); } } - TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( - stream.get(), shaped_buffer)); - return Status::OK(); + return transfer_manager->TransferLiteralFromDevice(stream.get(), + shaped_buffer, *literal); } Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index ddf2656e6f51775024a6d1cd0d7a387605faae6f..6519da30d02e41da5a862cadd2133bd8dd8b42d7 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -147,7 +147,7 @@ class XRTTupleAllocation : public ResourceBase { // Copies the allocation from device to host and returns it in literal. Status ToLiteral(xla::Backend* backend, int device_ordinal, - xla::Literal* literal); + xla::MutableLiteralBase* literal); // Write a new literal value to the allocation. Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 25f2640e35af5f65eab25dc60c44e3ed7ce4e512..0173b8bb064c7b2fb8a0df018204515b24cfa718 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -218,7 +218,6 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", - "//tensorflow/contrib/tpu:all_ops", ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index f0b1c92cf7e4b760381da38febd9682ce2a4f27c..5608e7ddafa25757484d8c845c8c84a5691e143c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -73,8 +73,7 @@ cc_binary( "-z defs", "-s", "-Wl,--gc-sections", - "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - "$(location {})".format(LINKER_SCRIPT), + "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 6138d7912601344ef7422fd50fb35c8401fd2e63..f0637595db08cbeb3b3ee0c94c5399df4c8c83e6 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { - namespace { class BigtableClientOp : public OpKernel { @@ -341,8 +340,8 @@ class ToBigtableOp : public AsyncOpKernel { } template - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { + Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { const Tensor* argument_t; TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); if (!TensorShapeUtils::IsScalar(argument_t->shape())) { @@ -360,5 +359,4 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), } // namespace } // namespace data - } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index e6fda9e61757f1441b3691c2a3d57c6f1a5a0d42..d9fce6e09f47ab05074f0b4c03dd8e672ed3d2ce 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -335,6 +335,17 @@ grpc::Status BigtableTestClient::ReadModifyWriteRow( return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "ReadModifyWriteRow not implemented."); } +std::unique_ptr> +BigtableTestClient::AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to AsyncReadModifyWriteRow:" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::unique_ptr< grpc::ClientReaderInterface> BigtableTestClient::ReadRows( diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index 8e1326f2ce841368ea81fc7194a0588e5d6cd637..63d59b32dd17a2f58d3413932b69f4d704c84e48 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -46,6 +46,13 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { google::bigtable::v2::ReadModifyWriteRowRequest const& request, google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + std::unique_ptr> + AsyncReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + grpc::CompletionQueue* cq) override; + std::unique_ptr< grpc::ClientReaderInterface> ReadRows(grpc::ClientContext* context, diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc index 416b719e30aa5f2504449d151a48e95c9105c68b..39c2a2e775d5d5287b137bf33eef66251738e6d3 100644 --- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -59,7 +59,7 @@ REGISTER_OP("BigtablePrefixKeyDataset") .Input("table: resource") .Input("prefix: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -68,14 +68,14 @@ REGISTER_OP("BigtableRangeKeyDataset") .Input("start_key: string") .Input("end_key: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("BigtableSampleKeysDataset") .Input("table: resource") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -85,7 +85,7 @@ REGISTER_OP("BigtableSampleKeyPairsDataset") .Input("start_key: string") .Input("end_key: string") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); @@ -100,7 +100,7 @@ REGISTER_OP("BigtableScanDataset") .Input("columns: string") .Input("probability: float") .Output("handle: variant") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index a178820841c4c8bcb7f5742babdb6d0f4825de31..5ffbb9067081d7440ab5e11290697b822051bee5 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -84,12 +84,10 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -179,8 +177,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): `[batch_size, label_dimension]`). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. - label_name: String, name of the key in label dict. Can be null if label - is a tensor (single headed models). + label_name: String, name of the key in label dict. Can be null if label is + a tensor (single headed models). weight_column_name: Name of the column for weights, or None if not weighted. model_dir: Directory for model exports, etc. @@ -195,11 +193,11 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -286,11 +284,11 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -353,10 +351,9 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. head: `Head` instance. - ranking_model_pair_keys: Keys to distinguish between features - for left and right part of the training pairs for ranking. For example, - for an Example with features "a.f1" and "b.f1", the keys would be - ("a", "b"). + ranking_model_pair_keys: Keys to distinguish between features for left and + right part of the training pairs for ranking. For example, for an + Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. weight_column_name: Name of the column for weights, or None if not @@ -376,12 +373,10 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -417,12 +412,12 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) + # When using this estimator, make sure to regularize the hessian (at least l2, # min_node_weight)! # TODO(nponomareva): extend to take multiple quantiles in one go. class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): - """An estimator that does quantile regression and returns quantile estimates. - """ + """An estimator that does quantile regression and returns quantile estimates.""" def __init__(self, learner_config, @@ -449,8 +444,8 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. quantiles: a list of quantiles for the loss, each between 0 and 1. - label_dimension: Dimension of regression label. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape + label_dimension: Dimension of regression label. This is the size of the + last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). When label_dimension>1, it is recommended to use multiclass strategy diagonal hessian or full hessian. num_trees: An int, number of trees to build. @@ -469,11 +464,11 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): opposed to contrib) version of tensorflow. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree override_global_step_value: If after the training is done, global step value must be reset to this value. This should be used to reset global step to a number > number of steps used to train the current ensemble. @@ -519,6 +514,7 @@ class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. @@ -534,10 +530,8 @@ def core_multiclass_head( def loss_fn(labels, logits): result = losses.per_example_maxent_loss( - labels=labels, - logits=logits, - weights=weight_column, - num_classes=n_classes) + # Don't pass the weights: head already multiplies by them. + labels=labels, logits=logits, weights=None, num_classes=n_classes) return result[0] # pylint:disable=protected-access @@ -564,7 +558,8 @@ def core_quantile_regression_head( result = losses.per_example_quantile_regression_loss( labels=labels, predictions=logits, - weights=weight_column, + # Don't pass the weights: head already multiplies by them. + weights=None, quantile=quantiles) return result[0] @@ -623,11 +618,11 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): the bias. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree num_quantiles: Number of quantiles to build for numeric feature values. """ @@ -685,10 +680,9 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. head: `Head` instance. - ranking_model_pair_keys: Keys to distinguish between features - for left and right part of the training pairs for ranking. For example, - for an Example with features "a.f1" and "b.f1", the keys would be - ("a", "b"). + ranking_model_pair_keys: Keys to distinguish between features for left and + right part of the training pairs for ranking. For example, for an + Example with features "a.f1" and "b.f1", the keys would be ("a", "b"). num_trees: An int, number of trees to build. feature_columns: A list of feature columns. weight_column_name: Name of the column for weights, or None if not @@ -703,12 +697,10 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is - [batch_size, num_trees]. - For example, - result_iter = classifier.predict(...) - for result_dict in result_iter: - # access leaf index list by result_dict["leaf_index"] - # which contains one leaf index per tree + [batch_size, num_trees]. For example, result_iter = + classifier.predict(...) + for result_dict in result_iter: # access leaf index list by + result_dict["leaf_index"] # which contains one leaf index per tree num_quantiles: Number of quantiles to build for numeric feature values. Raises: @@ -748,8 +740,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): # TODO(nponomareva): extend to take multiple quantiles in one go. class CoreGradientBoostedDecisionTreeQuantileRegressor( core_estimator.Estimator): - """An estimator that does quantile regression and returns quantile estimates. - """ + """An estimator that does quantile regression and returns quantile estimates.""" def __init__(self, learner_config, @@ -775,8 +766,8 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. quantiles: a list of quantiles for the loss, each between 0 and 1. - label_dimension: Dimension of regression label. This is the size - of the last dimension of the labels `Tensor` (typically, this has shape + label_dimension: Dimension of regression label. This is the size of the + last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). When label_dimension>1, it is recommended to use multiclass strategy diagonal hessian or full hessian. num_trees: An int, number of trees to build. @@ -795,11 +786,11 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( the bias. output_leaf_index: whether to output leaf indices along with predictions during inference. The leaf node indexes are available in predictions - dict by the key 'leaf_index'. For example, - result_dict = classifier.predict(...) - for example_prediction_result in result_dict: - # access leaf index list by example_prediction_result["leaf_index"] - # which contains one leaf index per tree + dict by the key 'leaf_index'. For example, result_dict = + classifier.predict(...) + for example_prediction_result in result_dict: # access leaf index list + by example_prediction_result["leaf_index"] # which contains one leaf + index per tree num_quantiles: Number of quantiles to build for numeric feature values. """ if len(quantiles) > 1: @@ -814,7 +805,9 @@ class CoreGradientBoostedDecisionTreeQuantileRegressor( params={ 'head': core_quantile_regression_head( - quantiles[0], label_dimension=label_dimension), + quantiles[0], + label_dimension=label_dimension, + weight_column=weight_column_name), 'feature_columns': feature_columns, 'learner_config': diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 47d910d42a27db4b857eeb12209dfbb429dd1be2..5a8b2ba9caf0a9813cb5b3409b8a0dc3de0a45d7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -399,8 +399,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def testQuantileRegression(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -413,7 +413,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, + num_trees=12, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -428,31 +428,12 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_below_upper >= 0.92) self.assertTrue(frac_below_upper <= 0.98) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - # Multi-dimensional quantile regression. def testQuantileRegressionMultiDimLabel(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.max_tree_depth = 6 + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -467,7 +448,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): quantiles=[0.95], learner_config=learner_config, label_dimension=2, - num_trees=100, + num_trees=18, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) @@ -490,35 +471,6 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_both_below_upper >= 0.91) self.assertTrue(frac_both_below_upper <= 0.99) - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - label_dimension=2, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.fit(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["scores"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) - class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -712,11 +664,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) - # One dimensional quantile regression. - def testQuantileRegression(self): + # Quantile regression in core is the same as in non core estimator, so we + # just check that it does not fail. + def testQuantileRegressionDoesNotThroughException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 + learner_config.constraints.max_tree_depth = 1 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE @@ -731,112 +684,12 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( quantiles=[0.95], learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_upper.train(input_fn=train_input_fn, steps=1000) - result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper >= 0.92) - self.assertTrue(frac_below_upper <= 0.98) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower >= 0.92) - self.assertTrue(frac_above_lower <= 0.98) - - # Multi-dimensional quantile regression. - def testQuantileRegressionMultiDimLabel(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 3 - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) - - train_input_fn, test_input_fn, y = _quantile_regression_input_fns( - two_dimension=True) - y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) - - # 95% percentile. - model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.95], - learner_config=learner_config, - num_trees=100, - label_dimension=2, + num_trees=1, examples_per_layer=_QUANTILE_REGRESSION_SIZE, center_bias=False) model_upper.train(input_fn=train_input_fn, steps=1000) result_iter = model_upper.predict(input_fn=test_input_fn) - upper = [] - for prediction_dict in result_iter: - upper.append(prediction_dict["predictions"]) - - count_below_upper = np.count_nonzero(upper > y, axis=0) - count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) - frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) - frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) - frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) - # +/- 3% - self.assertTrue(frac_below_upper_0 >= 0.92) - self.assertTrue(frac_below_upper_0 <= 0.98) - self.assertTrue(frac_below_upper_1 >= 0.92) - self.assertTrue(frac_below_upper_1 <= 0.98) - self.assertTrue(frac_both_below_upper >= 0.91) - self.assertTrue(frac_both_below_upper <= 0.99) - - train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( - two_dimension=True) - model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( - quantiles=[0.05], - learner_config=learner_config, - num_trees=100, - label_dimension=2, - examples_per_layer=_QUANTILE_REGRESSION_SIZE, - center_bias=False) - - model_lower.train(input_fn=train_input_fn, steps=1000) - result_iter = model_lower.predict(input_fn=test_input_fn) - lower = [] - for prediction_dict in result_iter: - lower.append(prediction_dict["predictions"]) - - count_above_lower = np.count_nonzero(lower < y, axis=0) - count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) - frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) - frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) - frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3) - # +/- 3% - self.assertTrue(frac_above_lower_0 >= 0.92) - self.assertTrue(frac_above_lower_0 <= 0.98) - self.assertTrue(frac_above_lower_1 >= 0.92) - self.assertTrue(frac_above_lower_1 <= 0.98) - self.assertTrue(frac_both_above_lower >= 0.91) - self.assertTrue(frac_both_above_lower <= 0.99) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index a6e422847d3914188bca9e6dff797ba1ffb06749..eecf3c5aeb6c6785cae3fd5808954a73db6190d6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training_util @@ -88,6 +89,12 @@ def model_builder(features, if config is None: raise ValueError("Missing estimator RunConfig.") + if config.session_config is not None: + session_config = config.session_config + session_config.allow_soft_placement = True + else: + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + config = config.replace(session_config=session_config) center_bias = params["center_bias"] diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 6d78e27e8f69ea289b686af8402bd91967f997f4..65276242abaf96de8b1936365278b18f8bba93a9 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -538,7 +538,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { partition_boundaries[non_empty_partitions[root_idx]]; float best_gain = std::numeric_limits::lowest(); - int32 best_dimension_idx = 0; bool default_right = false; int32 best_element_idx = 0; @@ -571,7 +570,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { const DimensionBoundary& dimension_and_start = dimension_boundaries[j]; - const int32 dimension_id = dimension_and_start.dimension_id; int start_index = dimension_and_start.start_index; // Even for the last dimension, we always have additional dummy @@ -630,7 +628,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { best_right_node_stats = right_stats_default_left; best_element_idx = element_idx; default_right = false; - best_dimension_idx = dimension_id; } } // Consider calculating the default direction only when there were @@ -648,7 +645,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel { best_right_node_stats = right_stats_default_right; best_element_idx = element_idx; default_right = true; - best_dimension_idx = dimension_id; } } } diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index c3685b54e201f73039f6623443c67ba2b217a51e..ad6ff0a861af896ef0dd254bd47752d76378d63a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -33,7 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 0c319cc9bd1f720eb404a9da05227c5807ec874f..aff7105e94729942efc6e3e9d3ae23b733e8f5ed 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index ad1191d41236e71008bff8c8a7fbd42c16e3f9c5..2a0a206d97bbf01ac382531df31a66d429842bbb 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,7 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 99ed4959fad9699f265183d71a1f3b609d7e6d30..7b3df962542a656af8052e9f2eae6e83744411f2 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -27,7 +27,7 @@ Managing dependencies: @@NoDependency @@split_dependency -Checkpointable data structures: +Trackable data structures: @@List @@Mapping @@UniqueNameTracker @@ -49,17 +49,16 @@ from tensorflow.contrib.checkpoint.python.python_state import NumpyState from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint -from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.core.protobuf.trackable_object_graph_pb2 import TrackableObjectGraph as CheckpointableObjectGraph from tensorflow.python.training.checkpoint_management import CheckpointManager -from tensorflow.python.training.checkpointable.base import Checkpointable as CheckpointableBase -from tensorflow.python.training.checkpointable.data_structures import List -from tensorflow.python.training.checkpointable.data_structures import Mapping -from tensorflow.python.training.checkpointable.data_structures import NoDependency -from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable -from tensorflow.python.training.checkpointable.util import capture_dependencies -from tensorflow.python.training.checkpointable.util import list_objects -from tensorflow.python.training.checkpointable.util import object_metadata - +from tensorflow.python.training.tracking.base import Trackable as CheckpointableBase +from tensorflow.python.training.tracking.data_structures import List +from tensorflow.python.training.tracking.data_structures import Mapping +from tensorflow.python.training.tracking.data_structures import NoDependency +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import capture_dependencies +from tensorflow.python.training.tracking.util import list_objects +from tensorflow.python.training.tracking.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 4e529322c7c76797938468b405cd175609dc0a73..cd9c94c9bd72d398d183d3f3d485ab48cb2fd617 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -12,7 +12,7 @@ py_library( ":python_state", ":split_dependency", ":visualize", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -22,8 +22,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:data_structures", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:data_structures", ], ) @@ -36,8 +36,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -47,7 +47,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", "//third_party/py/numpy", "@six_archive//:six", ], @@ -64,7 +64,7 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python:variables", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) @@ -76,7 +76,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) @@ -89,8 +89,8 @@ tf_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -101,8 +101,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python/training/checkpointable:base", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:base", + "//tensorflow/python/training/tracking:util", ], ) @@ -118,6 +118,7 @@ tf_py_test( "//tensorflow/python/eager:test", "//tensorflow/python/keras:engine", "//tensorflow/python/keras:layers", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], + tags = ["no_oss"], # b/124472244 ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 97936d9e9dfd5d6e62fdf8312707a276b63e1267..a25d51980ea760dfb7f323497a397fbd94fd5f23 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -1,4 +1,4 @@ -"""Checkpointable data structures.""" +"""Trackable data structures.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.training.checkpointable import base as checkpointable_lib -from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.tracking import base as trackable_lib +from tensorflow.python.training.tracking import data_structures -class UniqueNameTracker(data_structures.CheckpointableDataStructure): - """Adds dependencies on checkpointable objects with name hints. +class UniqueNameTracker(data_structures.TrackableDataStructure): + """Adds dependencies on trackable objects with name hints. Useful for creating dependencies with locally unique names. @@ -43,30 +43,30 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): def __init__(self): super(UniqueNameTracker, self).__init__() - self._maybe_initialize_checkpointable() + self._maybe_initialize_trackable() self._name_counts = {} @property def _values(self): return [dep.ref for dep in self._checkpoint_dependencies] - def track(self, checkpointable, base_name): - """Add a dependency on `checkpointable`. + def track(self, trackable, base_name): + """Add a dependency on `trackable`. Args: - checkpointable: An object to add a checkpoint dependency on. + trackable: An object to add a checkpoint dependency on. base_name: A name hint, which is uniquified to determine the dependency name. Returns: - `checkpointable`, for chaining. + `trackable`, for chaining. Raises: - ValueError: If `checkpointable` is not a checkpointable object. + ValueError: If `trackable` is not a trackable object. """ - if not isinstance(checkpointable, checkpointable_lib.Checkpointable): + if not isinstance(trackable, trackable_lib.Trackable): raise ValueError( - ("Expected a checkpointable value, got %s which does not inherit " - "from CheckpointableBase.") % (checkpointable,)) + ("Expected a trackable value, got %s which does not inherit " + "from tf.track.Trackable.") % (trackable,)) def _format_name(prefix, number): if number > 0: @@ -80,5 +80,5 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - self._track_value(checkpointable, name=candidate) - return checkpointable + self._track_value(trackable, name=candidate) + return trackable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index a2d453ec6eb3dcf9aba4c52fe866756a92673c63..bace21939602666aa48a05d2abfe05ae6aae41e2 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,9 +26,9 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import data_structures -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import data_structures +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util class UniqueNameTrackerTests(test.TestCase): @@ -52,7 +52,7 @@ class UniqueNameTrackerTests(test.TestCase): save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = tracking.AutoCheckpointable() + restore_slots = tracking.AutoTrackable() restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) @@ -68,7 +68,7 @@ class UniqueNameTrackerTests(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(tracking.AutoCheckpointable): + class SlotManager(tracking.AutoTrackable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py index 969c90c78871ebff02b360f8f09623df56c9c077..737a6c30c1dce65dd7638ee52e6c26a8a40f8321 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -23,7 +23,7 @@ import six import numpy -from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.tracking import base # pylint: disable=g-import-not-at-top try: @@ -34,8 +34,8 @@ except ImportError: # pylint: enable=g-import-not-at-top -class NumpyState(base.Checkpointable): - """A checkpointable object whose NumPy array attributes are saved/restored. +class NumpyState(base.Trackable): + """A trackable object whose NumPy array attributes are saved/restored. Example usage: @@ -72,7 +72,7 @@ class NumpyState(base.Checkpointable): """Create placeholder NumPy arrays for to-be-restored attributes. Typically `_lookup_dependency` is used to check by name whether a dependency - exists. We cheat slightly by creating a checkpointable object for `name` if + exists. We cheat slightly by creating a trackable object for `name` if we don't already have one, giving us attribute re-creation behavior when loading a checkpoint. @@ -85,7 +85,7 @@ class NumpyState(base.Checkpointable): value = super(NumpyState, self)._lookup_dependency(name) if value is None: value = _NumpyWrapper(numpy.array([])) - new_reference = base.CheckpointableReference(name=name, ref=value) + new_reference = base.TrackableReference(name=name, ref=value) self._unconditional_checkpoint_dependencies.append(new_reference) self._unconditional_dependency_names[name] = value super(NumpyState, self).__setattr__(name, value) @@ -101,7 +101,7 @@ class NumpyState(base.Checkpointable): def __setattr__(self, name, value): """Automatically wrap NumPy arrays assigned to attributes.""" # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making - # ndarrays checkpointable natively and using standard checkpointable list + # ndarrays trackable natively and using standard trackable list # tracking. if isinstance(value, (numpy.ndarray, numpy.generic)): try: @@ -110,19 +110,19 @@ class NumpyState(base.Checkpointable): return except AttributeError: value = _NumpyWrapper(value) - self._track_checkpointable(value, name=name, overwrite=True) + self._track_trackable(value, name=name, overwrite=True) elif (name not in ("_setattr_tracking", "_update_uid") and getattr(self, "_setattr_tracking", True)): - # Mixing restore()-created attributes with user-added checkpointable + # Mixing restore()-created attributes with user-added trackable # objects is tricky, since we can't use the `_lookup_dependency` trick to # re-create attributes (we might accidentally steal the restoration for - # another checkpointable object). For now `NumpyState` objects must be + # another trackable object). For now `NumpyState` objects must be # leaf nodes. Theoretically we could add some extra arguments to # `_lookup_dependency` to figure out whether we should create a NumPy # array for the attribute or not. raise NotImplementedError( ("Assigned %s to the %s property of %s, which is not a NumPy array. " - "Currently mixing NumPy arrays and other checkpointable objects is " + "Currently mixing NumPy arrays and other trackable objects is " "not supported. File a feature request if this limitation bothers " "you.") % (value, name, self)) @@ -130,7 +130,7 @@ class NumpyState(base.Checkpointable): @six.add_metaclass(abc.ABCMeta) -class PythonStateWrapper(base.Checkpointable): +class PythonStateWrapper(base.Trackable): """Wraps a Python object for storage in an object-based checkpoint.""" @abc.abstractmethod diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 45494351ff4e6c8c75634d8563c3fb63c6089036..40d8fe836402c8b6c8240ef9f665b753c54ede0d 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -26,7 +26,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variables -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import util class NumpyStateTests(test.TestCase): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 3e9700ad74618e24843181d169f3fb39ac96bff6..d7b02b538909305b14e638761bd8ba67a948d2b4 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -21,7 +21,7 @@ import functools from tensorflow.python.ops import control_flow_ops from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(checkpointable.Checkpointable): +class _SplitDependency(trackable.Trackable): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, @@ -81,9 +81,9 @@ class _SplitDependency(checkpointable.Checkpointable): return control_flow_ops.no_op() def _gather_saveables_for_checkpoint(self): - """Looks to Checkpointable like a regular variable.""" + """Looks to Trackable like a regular variable.""" return { - checkpointable.VARIABLE_VALUE_KEY: + trackable.VARIABLE_VALUE_KEY: functools.partial(_CallbackSaveable, dtype=self._dtype, save_callback=self._save, @@ -117,7 +117,7 @@ def split_dependency(component_names, component_dtypes, may return `None`). Returns: - A dictionary mapping from names to Checkpointable objects. If one is + A dictionary mapping from names to Trackable objects. If one is reachable from an object as a dependency, the others should be too; adding dependencies on some but not all of the objects will result in errors. """ diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 664a4e76ab31bf31c7a57924e4af866f2d746804..9bc01059481ff69064e3f9c682a764146b79a250 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,9 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import base +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util def _split_variable_closure(variable): @@ -44,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(base.Checkpointable): +class SaveTensorSlicesAsDeps(base.Trackable): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -56,17 +56,17 @@ class SaveTensorSlicesAsDeps(base.Checkpointable): consume_restore_buffer_fn=_combine_variable_closure( self.combined)) for name, dep in split_dependencies.items(): - self._track_checkpointable(dep, name=name) + self._track_trackable(dep, name=name) -class HasRegularDeps(tracking.AutoCheckpointable): +class HasRegularDeps(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(tracking.AutoCheckpointable): +class OnlyOneDep(tracking.AutoTrackable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index bac071c4cff383f60b707b6e42c13faf5e0ac948..faf90f018476b3c70a7bfa1346a5b590edbbddcd 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.training.tracking import util as trackable_utils def dot_graph_from_checkpoint(save_path): @@ -51,7 +51,7 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. """ reader = pywrap_tensorflow.NewCheckpointReader(save_path) - object_graph = checkpointable_utils.object_metadata(save_path) + object_graph = trackable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() graph = 'digraph {\n' @@ -63,7 +63,7 @@ def dot_graph_from_checkpoint(save_path): slot_ids.add(slot_reference.slot_variable_node_id) for node_id, node in enumerate(object_graph.nodes): if (len(node.attributes) == 1 - and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY): + and node.attributes[0].name == trackable.VARIABLE_VALUE_KEY): if node_id in slot_ids: color = 'orange' tooltip_prefix = 'Slot variable' diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 583e3bc442893d825c337d73fb999d1e586738a1..98a22d573fdb6172cde100df461d9ae520c2c483 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -28,7 +28,7 @@ from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils try: import pydot # pylint: disable=g-import-not-at-top @@ -57,7 +57,7 @@ class DotGraphTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = resource_variable_ops.ResourceVariable(12) - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) optimizer.minimize(functools.partial(model, input_value)) checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index e570c09ecb5e64130ed6f3375a51d74850cc3989..30b4e2dbdee1117df12ae7ab8ce902e667234fb0 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 69b6c047bc767b4d80e7af4d00ccb7c45b683dae) +set(GRPC_TAG 62688b6a05cc85b47fb77dd408611734253e47e2) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 21ae9a08a6bb8f71e5935ddde2d7bb3ed0cd8bbc..3d86ab9abbb4cc90c406edc6237c0d2abe440122 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -13,6 +13,7 @@ tensorflow/core/lib tensorflow/core/lib/core tensorflow/core/profiler tensorflow/core/protobuf +tensorflow/core/protobuf/tpu tensorflow/core/util tensorflow/examples tensorflow/examples/tutorials @@ -71,7 +72,7 @@ tensorflow/python/tools tensorflow/python/tools/api tensorflow/python/tools/api/generator tensorflow/python/training -tensorflow/python/training/checkpointable +tensorflow/python/training/tracking tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -437,7 +438,6 @@ tensorflow/contrib/timeseries/python/timeseries/state_space_models tensorflow/contrib/tpu tensorflow/contrib/tpu/ops tensorflow/contrib/tpu/profiler -tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/python tensorflow/contrib/tpu/python/ops tensorflow/contrib/tpu/python/profiler diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 013180c89083748b240ad061b342300e886d3568..b4603206da419f44af0857b9b933eb7df1b255ff 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -1,6 +1,7 @@ tensorflow/core tensorflow/core/kernels/boosted_trees tensorflow/core/profiler +tensorflow/core/protobuf/tpu tensorflow/python tensorflow/contrib/boosted_trees/proto tensorflow/contrib/cloud/kernels @@ -12,7 +13,6 @@ tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto tensorflow/contrib/tensorboard/plugins/projector -tensorflow/contrib/tpu/proto tensorflow/contrib/tpu/profiler tensorflow/contrib/training/python/training tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index d8d1cc3aa2ca4fff3c950654b7cbd7085c76010c..cc263d7995c01100f1c51436bcb584b600c8c161 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -125,9 +125,9 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" + "${tensorflow_source_dir}/tensorflow/core/protobuf/tpu/*.proto" "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index e32097ceddfec95b8677fc762d641d09078e5343..79c61589112b739837b401010690e7f4ca917d07 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -23,6 +23,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":xla", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", ], diff --git a/tensorflow/contrib/compiler/__init__.py b/tensorflow/contrib/compiler/__init__.py index c4937dadfb8be3211377f0ae7017b95e7642dab0..797e5e8164e231e8b3806d40b32774711879b050 100644 --- a/tensorflow/contrib/compiler/__init__.py +++ b/tensorflow/contrib/compiler/__init__.py @@ -19,3 +19,4 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.compiler import jit +from tensorflow.contrib.compiler import xla diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index ca92c31236a7a3882415834eb32a994a120b6d2d..403f30909520dc5cd5f5919af843291fe1400b91 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -58,7 +58,7 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -709,7 +709,7 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): self._TestSaveRestoreHelper(CUDNN_RNN_RELU) -class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): +class CudnnRNNTestSaveRestoreTrackable(test_util.TensorFlowTestCase): def _VerifyCheckpoint( self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn, @@ -718,7 +718,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") with ops.device("gpu:0"): cudnn_layer = cudnn_cell_fn() - cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer) + cudnn_checkpoint = trackable_utils.Checkpoint(cell=cudnn_layer) status = cudnn_checkpoint.restore(checkpoint_path) inputs = 3. * array_ops.ones([num_applications, num_layers, input_size], dtype=dtypes.float32) @@ -726,7 +726,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): status.run_restore_ops() second_save_path = cudnn_checkpoint.save(checkpoint_prefix) restore_layer = compatible_cell_fn() - restore_layer_checkpoint = checkpointable_utils.Checkpoint( + restore_layer_checkpoint = trackable_utils.Checkpoint( cell=restore_layer) status = restore_layer_checkpoint.restore(second_save_path) current_state = restore_layer.zero_state(1, dtypes.float32) @@ -742,7 +742,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): self.assertAllClose(self.evaluate(restore_layer_output), self.evaluate(cudnn_output)[-1, -1:, ...]) - def _CheckpointableSingleCellUnidirectionalTestTemplate( + def _TrackableSingleCellUnidirectionalTestTemplate( self, single_cell_fn, cudnn_cell_fn): # Single-layer cuDNN cells with object-based checkpointing should be # checkpoint compatible with either single CudnnCompatible cells or @@ -759,7 +759,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_cell_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -775,10 +775,10 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testLSTMCheckpointableSingleLayer(self): + def testLSTMTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( @@ -788,19 +788,19 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @test_util.run_in_graph_and_eager_modes - def testGRUCheckpointableSingleLayer(self): + def testGRUTrackableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION with self.assertRaises(NotImplementedError): # TODO(allenl): Implement object-based saving for GRUs and other cells. - self._CheckpointableSingleCellUnidirectionalTestTemplate( + self._TrackableSingleCellUnidirectionalTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units), cudnn_cell_fn=functools.partial( cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units, direction=direction, name="awesome_gru")) - def _CheckpointableMultiLayerTestTemplate( + def _TrackableMultiLayerTestTemplate( self, single_cell_fn, cudnn_cell_fn, num_layers): def _MultiCellFn(): @@ -819,7 +819,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): value = np.random.normal(size=variable.shape) expected_values.append(value) self.evaluate(variable.assign(value)) - save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer) + save_checkpoint = trackable_utils.Checkpoint(cell=save_layer) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") first_save_path = save_checkpoint.save(checkpoint_prefix) @@ -837,7 +837,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): num_units = 2 num_layers = 3 direction = CUDNN_RNN_UNIDIRECTION - self._CheckpointableMultiLayerTestTemplate( + self._TrackableMultiLayerTestTemplate( single_cell_fn=functools.partial( cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), cudnn_cell_fn=functools.partial( diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 86ad8ae8073714657c78badb1e0b4a6d8c8ed5f0..1cb477716dfc6a9cc793939059784f9d89bcdd8a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -518,8 +518,8 @@ class _CudnnRNN(base_layer.Layer): direction=self.direction, scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) - self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access - checkpointable=self, dtype=self._plain_dtype) + self._saveable._add_trackable_dependencies( # pylint: disable=protected-access + trackable=self, dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index f36e8d5022bc7e3f8268a161089153e5510dffc6..7d848e2ec2d99cd2a78ff3e813207c0cd5bb97cf 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking as checkpointable_lib +from tensorflow.python.training.tracking import tracking as trackable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -737,13 +737,13 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return state_ops.assign( self._variables, opaque_params, validate_shape=False) - def _checkpointable_save(self, save_buffer): + def _trackable_save(self, save_buffer): weights, biases = self.format_converter.opaque_to_tf_canonical( self._variables) for name, tensor in zip(self._param_names, weights + biases): save_buffer[name] = array_ops.identity(tensor) - def _checkpointable_restore(self, restore_buffer): + def _trackable_restore(self, restore_buffer): tensors = [ array_ops.identity(restore_buffer[name]) for name in self._param_names ] @@ -752,26 +752,26 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): restored_shapes=None # Unused ) - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. + def _add_trackable_dependencies(self, trackable, dtype): + """Add canonical weight dependencies to `trackable`. When saving or restoring, converts to or from the opaque buffer format. Weights are saved and loaded in the configuration expected by cuDNN-compatible cells. Args: - checkpointable: An object inheriting from `CheckpointableBase` to add + trackable: An object inheriting from `Trackable` to add dependencies too (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. """ split_dependencies = split_dependency.split_dependency( component_names=self._param_names, component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) + fill_save_buffer_fn=self._trackable_save, + consume_restore_buffer_fn=self._trackable_restore) + self._trackable_track_params(trackable, split_dependencies) - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Tracks parameters in a canonical configuration.""" return # NotImplementedError raised by the Layer. @@ -819,7 +819,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): tf_weights_names.append(prefix + "/kernel") tf_bias_names.append(prefix + "/bias") - def _checkpointable_track_params(self, checkpointable, params): + def _trackable_track_params(self, trackable, params): """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" biases = [] weights = [] @@ -833,12 +833,12 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # wrapping. kernel, = weights # pylint: disable=unbalanced-tuple-unpacking bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + trackable._track_trackable(kernel, name="kernel") # pylint: disable=protected-access + trackable._track_trackable(bias, name="bias") # pylint: disable=protected-access assert len(biases) == len(weights) for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.AutoCheckpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell = trackable_lib.AutoTrackable() + trackable._track_trackable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access cell.bias = bias cell.kernel = kernel diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 6c5f8c6b00975b3fba041271309a93cecd9f5057..4db711c1f3f2815e7b8cf275af315c062ce4c02e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -25,11 +25,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class AssertElementShapeTest(test_base.DatasetTestBase): def test_assert_element_shape(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index b9840b1ff1a3df5a05db0e64f436637220f49f80..220f9934b67d1d2a97f6c0fd4ba7779f011e1b09 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -27,12 +27,14 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.platform import test from tensorflow.python.util import compat prefix_path = "tensorflow/core/lib" +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py index e7281d531870c75c638b5c48fa3fc6dc606a3623..78019fcc7d810da444f1407f3885d54e76a741c6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py @@ -25,10 +25,12 @@ from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 2527706709fae8e459aca3489324d4db3c784be6..9275a36582a8c82b936659041129b71e100f883e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -26,11 +26,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +@test_util.run_v1_only("deprecated API, no eager or V2 test coverage") class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 509eb78128d062c7ea44730c2797b7c919cd0d69..2ab94d00565376bfebd80ee61094831e09ed3e68 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -131,6 +131,16 @@ py_library( ], ) +cuda_py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + additional_deps = [ + ":strategy_test_lib", + ":combinations", + "//tensorflow/python/eager:test", + ], +) + py_library( name = "collective_all_reduce_strategy", srcs = ["collective_all_reduce_strategy.py"], @@ -192,18 +202,6 @@ py_test( ], ) -py_test( - name = "one_device_strategy_test", - srcs = ["one_device_strategy_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":strategy_test_lib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/distribute:one_device_strategy", - "//tensorflow/python/eager:test", - ], -) - # TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", @@ -517,6 +515,7 @@ cuda_py_test( name = "cross_device_ops_test", srcs = ["cross_device_ops_test.py"], additional_deps = [ + ":collective_all_reduce_strategy", ":combinations", ":multi_worker_test_base", ":mirrored_strategy", @@ -540,6 +539,7 @@ py_library( srcs = [ "keras_backward_compat_test.py", "keras_test.py", + "keras_utils_test.py", ], deps = [ ":combinations", @@ -572,6 +572,24 @@ distribute_py_test( ], ) +distribute_py_test( + name = "keras_utils_test", + srcs = ["keras_utils_test.py"], + full_precision = True, + main = "keras_utils_test.py", + shard_count = 32, + tags = [ + "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. + "no_windows_gpu", + "notsan", + ], + deps = [ + ":keras_test", + ":keras_test_lib", + ], +) + # TODO(b/121200287): Remove this in 2.0 distribute_py_test( name = "keras_backward_compat_test", @@ -783,6 +801,6 @@ tf_xla_py_test( ":tpu_strategy", "//tensorflow/compiler/tests:xla_test", "//tensorflow/python/eager:test", - "//tensorflow/python/training/checkpointable:util", + "//tensorflow/python/training/tracking:util", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpointing_test.py b/tensorflow/contrib/distribute/python/checkpointing_test.py index aa5b9f57b8a5bc12ee94399ec1fc5a55177a5b5d..eadf7233f2ae5ee50b71836ebfcc895163124ac2 100644 --- a/tensorflow/contrib/distribute/python/checkpointing_test.py +++ b/tensorflow/contrib/distribute/python/checkpointing_test.py @@ -30,15 +30,15 @@ from tensorflow.python.platform import test from tensorflow.python.training import adam as adam_v1 from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util as trackable_utils -class NonLayerCheckpointable(tracking.AutoCheckpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( + super(NonLayerTrackable, self).__init__() + self.a_variable = trackable_utils.add_variable( self, name="a_variable", shape=[]) @@ -49,8 +49,8 @@ class Subclassed(training.Model): super(Subclassed, self).__init__() self._named_dense = core.Dense(1, use_bias=True) self._second = core.Dense(1, use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. + self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -76,7 +76,7 @@ class TrainingCheckpointTests(xla_test.XLATestCase): with strategy.scope(): model = Subclassed() optimizer = adam_v1.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = trackable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) root.restore(checkpoint_management.latest_checkpoint( diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index acbe4677b401cbea4fd0ec415415f25c920e68e4..ee7640dd1cea15e62ae9912ebedbd853778364a6 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -410,6 +410,7 @@ class DistributedCollectiveAllReduceStrategyTest( num_gpus=num_gpus, use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( combinations.combine( @@ -418,7 +419,8 @@ class DistributedCollectiveAllReduceStrategyTest( required_gpus=1, use_dataset=[True, False], use_core_strategy=[True, False])) - def testMakeInputFnIterator(self, num_gpus, use_dataset, use_core_strategy): + def DISABLED_testMakeInputFnIterator(self, num_gpus, use_dataset, + use_core_strategy): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -553,7 +555,7 @@ class LocalCollectiveAllReduceStrategy( required_gpus=2, use_dataset=[True, False], use_core_strategy=[True, False])) - def testMakeInputFnIterator(self, use_dataset, use_core_strategy): + def DISABLED_testMakeInputFnIterator(self, use_dataset, use_core_strategy): num_gpus = 2 if use_dataset: fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 798a1591c73c4f4f3f37b015d20ec31c40aaa939..7c0f8033fbc046580bc46f90ee9945ffa2a718f9 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -352,6 +352,9 @@ default_strategy = NamedDistribution( one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) +one_device_strategy_gpu = NamedDistribution( + "OneDeviceGPU", lambda: one_device_lib.OneDeviceStrategy("/gpu:0"), + required_gpus=1) tpu_strategy = NamedDistribution( "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py index 54cce2988383fcf5e063726948fbbf62c7094ce5..2b8e0197961ae37b67dc8958054a03e164242dcd 100644 --- a/tensorflow/contrib/distribute/python/cross_device_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -23,6 +23,7 @@ import itertools from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base @@ -204,15 +205,15 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): reduction_to_one_combinations = combinations.combine( cross_device_ops=[ combinations.NamedObject( - "DefaultReductionToOneDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + "DefaultReductionToOneDevice", + cross_device_ops_lib.ReductionToOneDevice()), combinations.NamedObject( "ReductionToCPUDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDevice( reduce_to_device=_cpu_device)), combinations.NamedObject( "AccumulateNCrossDeviceOp", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDevice( accumulation_fn=math_ops.accumulate_n)), ], distribution=[ @@ -228,20 +229,23 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): combinations.NamedObject( "AllReduce", cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_device_ops_lib.AllReduceCrossDeviceOps( - "hierarchical_copy", 8, 0, 0)), combinations.NamedObject( "AllReduceNoGradientRepacking", cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + combinations.NamedObject("NcclAllReduce", + cross_device_ops_lib.NcclAllReduce()), + combinations.NamedObject( + "HierarchicalCopy", + cross_device_ops_lib.HierarchicalCopyAllReduce(8)), combinations.NamedObject( "HierarchicalCopyAggregateSmallTensors", cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 0, 100, 10)) ], - distribution=[combinations.mirrored_strategy_with_two_gpus, - combinations.core_mirrored_strategy_with_two_gpus], + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus + ], mode=["graph", "eager"]) @combinations.generate(reduction_to_one_combinations + allreduce_combinations) @@ -306,8 +310,8 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): combinations.combine( cross_device_ops_instance=[ combinations.NamedObject( - "ReductionToOneDeviceCrossDeviceOps", - cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + "ReductionToOneDevice", + cross_device_ops_lib.ReductionToOneDevice()), combinations.NamedObject( "AllReduceCrossDeviceOps", cross_device_ops_lib.AllReduceCrossDeviceOps()) @@ -426,6 +430,9 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_device_ops, distribution) +NUM_WORKERS = 3 + + class MultiWorkerCollectiveAllReduceTest( multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): @@ -433,9 +440,9 @@ class MultiWorkerCollectiveAllReduceTest( @classmethod def setUpClass(cls): - """Create a local cluster with 2 workers.""" + """Create a local cluster with 3 workers.""" cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=0) + num_workers=NUM_WORKERS, num_ps=0) def setUp(self): super(MultiWorkerCollectiveAllReduceTest, self).setUp() @@ -443,7 +450,12 @@ class MultiWorkerCollectiveAllReduceTest( # collective key base for different tests. MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 - def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + def _get_test_objects(self, + task_type, + task_id, + num_gpus=0, + use_strategy_object=False, + local_mode=False): collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + MultiWorkerCollectiveAllReduceTest.collective_key_base, @@ -452,16 +464,24 @@ class MultiWorkerCollectiveAllReduceTest( instance_key_with_id_start=num_gpus * 10000 + MultiWorkerCollectiveAllReduceTest.collective_key_base) if local_mode: - collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( - 1, num_gpus, collective_keys=collective_keys) if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] else: devices = ["/device:CPU:0"] - return collective_all_reduce_ops, devices, "" + + if use_strategy_object: + # Still using contrib CollectiveAllReduceStrategy because we can specify + # num_gpus in its constructor. + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = collective_keys + return strategy, devices, "" + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + return collective_all_reduce_ops, devices, "" else: - collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( - 3, num_gpus, collective_keys=collective_keys) if num_gpus: devices = [ "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) @@ -469,8 +489,23 @@ class MultiWorkerCollectiveAllReduceTest( ] else: devices = ["/job:%s/task:%d" % (task_type, task_id)] - return (collective_all_reduce_ops, devices, - "grpc://" + self._cluster_spec[task_type][task_id]) + + if use_strategy_object: + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + strategy.configure( + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + strategy.extended._collective_keys = collective_keys + strategy.extended._cross_device_ops._collective_keys = collective_keys + return (strategy, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + NUM_WORKERS, num_gpus, collective_keys=collective_keys) + return (collective_all_reduce_ops, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) def _assert_values_equal(self, left, right, sess): if isinstance(left, list): @@ -490,9 +525,18 @@ class MultiWorkerCollectiveAllReduceTest( for l, r in zip(left_values, right_values): self.assertEqual(l, r) - def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + def _test_reduction(self, + task_type, + task_id, + num_gpus, + use_strategy_object=False, + local_mode=False): collective_all_reduce, devices, master_target = self._get_test_objects( - task_type, task_id, num_gpus, local_mode=local_mode) + task_type, + task_id, + num_gpus, + use_strategy_object=use_strategy_object, + local_mode=local_mode) if local_mode: num_workers = 1 worker_device = None @@ -500,6 +544,27 @@ class MultiWorkerCollectiveAllReduceTest( num_workers = len(self._cluster_spec.get("chief", [])) + len( self._cluster_spec.get("worker", [])) worker_device = "/job:%s/task:%d" % (task_type, task_id) + + def _reduce(test_object, reduce_op, per_replica, destinations): + if use_strategy_object: + with test_object.scope(): + # Mimic the behavior that distribution strategy usually strips the + # wrapper if there is only one value. + if len(per_replica.values) == 1: + per_replica = per_replica.values[0] + return test_object.extended.reduce_to(reduce_op, per_replica, + destinations) + else: + return test_object.reduce(reduce_op, per_replica, destinations) + + def _batch_reduce(test_object, reduce_op, value_destination_pairs): + if use_strategy_object: + with test_object.scope(): + return test_object.extended.batch_reduce_to(reduce_op, + value_destination_pairs) + else: + return test_object.batch_reduce(reduce_op, value_destination_pairs) + with ops.Graph().as_default(), \ ops.device(worker_device), \ self.cached_session(target=master_target) as sess: @@ -524,26 +589,30 @@ class MultiWorkerCollectiveAllReduceTest( # test reduce() for destinations in all_destinations: self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.MEAN, per_replica, - destinations=destinations), - _fake_mirrored(mean, destinations), sess) + destinations=destinations), _fake_mirrored(mean, destinations), + sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.MEAN, per_replica_2, - destinations=destinations), - _fake_mirrored(mean_2, destinations), sess) + destinations=destinations), _fake_mirrored( + mean_2, destinations), sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( - collective_all_reduce.reduce( + _reduce( + collective_all_reduce, reduce_util.ReduceOp.SUM, per_replica_2, destinations=destinations), @@ -553,17 +622,13 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, - [(per_replica, d1), - (per_replica_2, d2)]), - [ - _fake_mirrored(mean, d1), - _fake_mirrored(mean_2, d2) - ], sess) + _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), + [_fake_mirrored(mean, d1), + _fake_mirrored(mean_2, d2)], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, - [(per_replica, d1), - (per_replica_2, d2)]), + _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.SUM, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices) * num_workers, d1), _fake_mirrored(mean_2 * len(devices) * num_workers, d2) @@ -572,18 +637,36 @@ class MultiWorkerCollectiveAllReduceTest( return True @combinations.generate( - combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) - def testReductionDistributed(self, num_gpus): + combinations.combine( + mode=["graph"], + num_gpus=[0, 1, 2], + required_gpus=1, + use_strategy_object=[True, False])) + def testReductionDistributed(self, num_gpus, use_strategy_object): if context.num_gpus() < num_gpus: return - self._run_between_graph_clients(self._test_reduction, self._cluster_spec, - num_gpus) + self._run_between_graph_clients( + self._test_reduction, + self._cluster_spec, + num_gpus, + use_strategy_object=use_strategy_object) # Collective ops doesn't support strategy with one device. - def testReductionLocal(self, num_gpus=2): + @combinations.generate( + combinations.combine( + mode=["graph"], + num_gpus=[2], + required_gpus=2, + use_strategy_object=[True, False])) + def testReductionLocal(self, num_gpus, use_strategy_object): if context.num_gpus() < num_gpus: return - self._test_reduction(None, None, num_gpus, local_mode=True) + self._test_reduction( + None, + None, + num_gpus, + use_strategy_object=use_strategy_object, + local_mode=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index e17085628ba6d1dfc79839fd824801723f07a518..1ff1e7c1d255492e0535175dae7594d2ceb4010b 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -22,7 +22,6 @@ import shutil import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.optimizer_v2 import adagrad @@ -117,7 +116,7 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) predictions = np.array([ x[prediction_keys.PredictionKeys.PREDICTIONS] diff --git a/tensorflow/contrib/distribute/python/input_lib_test.py b/tensorflow/contrib/distribute/python/input_lib_test.py index 10a58316ec5b3d9d968a88c5c39ff70c277daa65..204f52b034f2366a42fbdab41c467feddb5969a0 100644 --- a/tensorflow/contrib/distribute/python/input_lib_test.py +++ b/tensorflow/contrib/distribute/python/input_lib_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import input_lib @@ -214,33 +213,5 @@ class InputIteratorMultiWorkerTest( expected_values, sess) -class SplitDatasetBatchTest(test.TestCase): - - def testBatchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testMapAndBatchDataset(self): - dataset = dataset_ops.Dataset.range(100) - dataset = dataset.apply(batching.map_and_batch(lambda x: x, 20)) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - def testPrefetchDataset(self): - dataset = dataset_ops.Dataset.range(100).batch(20).prefetch(1) - split_batch_by = 2 - result_dataset = input_lib._split_dataset_batch(dataset, split_batch_by) - expected_values = [range(i, i+10) for i in range(0, 100, 10)] - result = [self.evaluate(el) for el in result_dataset] - self.assertAllEqual(expected_values, result) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py index 9a581e7141af4a6625246539bc48835e6a920887..c49b5522f9135efd9ae3005e92099caf54a76a3a 100644 --- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -31,10 +31,10 @@ from tensorflow.python.framework import random_seed from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop -from tensorflow.python.training.mode_keys import ModeKeys _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 diff --git a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py index dae32188917cce9209b8e51032ef808352bc257c..61202e30c4f33892d2675080fae07cc4d7102337 100644 --- a/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py +++ b/tensorflow/contrib/distribute/python/keras_dnn_correctness_test.py @@ -47,7 +47,9 @@ class TestDistributionStrategyDnnCorrectness( # We add few non-linear layers to make it non-trivial. model = keras.Sequential() model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) - model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense( + 10, activation='relu', + kernel_regularizer=keras.regularizers.l2(1e-4))) model.add(keras.layers.Dense(10, activation='relu')) model.add(keras.layers.Dense(1)) diff --git a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py index f625664372dfb6814ccbe9539f6abe018d2a4447..3c2961456b2eede9570ce29f7a8900834f2ccfb7 100644 --- a/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py +++ b/tensorflow/contrib/distribute/python/keras_image_model_correctness_test.py @@ -23,7 +23,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import keras_correctness_test_base from tensorflow.python import keras from tensorflow.python.eager import test -from tensorflow.python.training import gradient_descent +from tensorflow.python.keras.optimizer_v2 import gradient_descent class DistributionStrategyCnnCorrectnessTest( @@ -33,7 +33,8 @@ class DistributionStrategyCnnCorrectnessTest( with keras_correctness_test_base.MaybeDistributionScope(distribution): image = keras.layers.Input(shape=(28, 28, 3), name='image') c1 = keras.layers.Conv2D( - name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4))( + name='conv1', filters=16, kernel_size=(3, 3), strides=(4, 4), + kernel_regularizer=keras.regularizers.l2(1e-4))( image) if self.with_batch_norm: c1 = keras.layers.BatchNormalization(name='bn1')(c1) @@ -47,7 +48,7 @@ class DistributionStrategyCnnCorrectnessTest( model.set_weights(initial_weights) model.compile( - optimizer=gradient_descent.GradientDescentOptimizer( + optimizer=gradient_descent.SGD( learning_rate=0.1), loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index 952b11932b83d16558ac9f5ce780886d94e72744..c93d7afa7ceef2c9c272e91997e2871655cea079 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -64,7 +64,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): def loss_fn(): replica_id = _replica_id() - return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * var + return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * 0.5 * var train_op = optimizer.minimize(loss_fn, var_list=[var]) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index dd975c6c36d5d5387035e9da4170e4072406d79c..77e241974f7c4c27382ab548a202891fdbbc6ba0 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -17,9 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import os -import tempfile from absl.testing import parameterized import numpy as np @@ -27,17 +25,17 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.python import keras +from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import values from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -72,6 +70,20 @@ def simple_functional_model(): return model +def simple_subclassed_model(num_labels=_NUM_CLASS): + + class _SimpleMLP(keras.Model): + + def __init__(self, num_labels): + super(_SimpleMLP, self).__init__() + self.dense = keras.layers.Dense(num_labels) + + def call(self, inputs): + return self.dense(inputs) + + return _SimpleMLP(num_labels) + + def simple_multi_inputs_multi_outputs_model(): input_a = keras.layers.Input(shape=(16,), name='input_a') input_b = keras.layers.Input(shape=(16,), name='input_b') @@ -216,6 +228,22 @@ def get_predict_dataset(distribution): return dataset +def convert_numpy_to_dataset_with_unknown_cardinality(inputs, + targets=None): + if targets is not None: + input_slices = (inputs, targets) + dummy_op = (lambda inp, target: True) + else: + input_slices = inputs + dummy_op = (lambda inp: True) + + original_dataset = (dataset_ops.Dataset.from_tensor_slices( + input_slices)) + ds_with_unknown_cardinality = (original_dataset.filter(dummy_op). + batch(10, drop_remainder=True)) + return ds_with_unknown_cardinality + + def multi_input_output_model(): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(5,), name='input_b') @@ -230,9 +258,12 @@ def multi_input_output_model(): return model +# TODO(josh11b): Add combinations.one_device_strategy_gpu once it works with +# TestDistributionStrategyWithCallbacks.test_callbacks_in_predict. strategies_minus_tpu = [ combinations.default_strategy, combinations.one_device_strategy, + combinations.one_device_strategy_gpu, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, combinations.core_mirrored_strategy_with_gpu_and_cpu, @@ -244,15 +275,13 @@ tpu_strategies = [ def strategy_minus_tpu_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu, - mode=['graph', 'eager']) + return combinations.combine(distribution=strategies_minus_tpu, + mode=['graph', 'eager']) def tpu_strategy_combinations(): - return combinations.combine( - distribution=tpu_strategies, - mode=['graph']) + return combinations.combine(distribution=tpu_strategies, + mode=['graph']) def all_strategy_combinations(): @@ -263,6 +292,7 @@ def all_strategy_combinations_minus_default(): strategy_minus_default_combinations = combinations.combine( distribution=[ combinations.one_device_strategy, + combinations.one_device_strategy_gpu, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, combinations.core_mirrored_strategy_with_gpu_and_cpu, @@ -286,12 +316,6 @@ def strategy_and_optimizer_combinations(): ])) -def strategy_for_numpy_input_combinations(): - return combinations.combine( - distribution=strategies_minus_tpu + tpu_strategies, - mode=['graph']) - - class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, parameterized.TestCase): @@ -447,7 +471,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_no_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies # that use per_core_batch_size @@ -478,7 +502,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=None, batch_size=None) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_with_steps_no_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -524,7 +548,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_63_samples, steps=1, batch_size=None) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_no_steps_with_batch_size(self, distribution): # Calculate the per_replica_batch_size scaling factor for strategies @@ -558,7 +582,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=None, batch_size=3) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calculating_input_params_with_steps_with_batch_size(self, distribution): with self.cached_session(): @@ -575,7 +599,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, distributed_training_utils.get_input_params( distribution, input_64_samples, steps=10, batch_size=13) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): with distribution.scope(): @@ -606,7 +630,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): with distribution.scope(): @@ -658,7 +682,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_for_numpy_input_combinations()) + @combinations.generate(all_strategy_combinations()) def test_flatten_predict_outputs(self, distribution): with self.cached_session(): with distribution.scope(): @@ -841,9 +865,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - # TODO(b/122743976): Include TPUStrategy for this test as well once - # step inference is supported. - @combinations.generate(strategy_minus_tpu_combinations()) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset_without_steps( self, distribution): with self.cached_session(): @@ -864,7 +886,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, predict_with_numpy = model.predict(inputs, batch_size=10) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.batch(10) + dataset = dataset.batch(10, drop_remainder=True) fit_with_ds = model.fit(dataset, epochs=1).history eval_with_ds = model.evaluate(dataset) predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) @@ -877,6 +899,61 @@ class TestDistributionStrategyWithDatasets(test.TestCase, self.assertAllClose( predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + @combinations.generate(all_strategy_combinations()) + def test_on_dataset_with_unknown_cardinality_without_steps( + self, distribution): + with self.cached_session(): + with distribution.scope(): + model = get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((1000, 3), dtype=np.float32) + targets = np.zeros((1000, 4), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. + fit_with_numpy = model.fit(inputs, targets, epochs=1, + batch_size=10).history + fit_with_numpy_multiple_epochs = model.fit( + inputs, targets, epochs=2, batch_size=10).history + eval_with_numpy = model.evaluate(inputs, targets, batch_size=10) + predict_with_numpy = model.predict(inputs, batch_size=10) + + dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs, targets) + predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs) + + self.assertEqual(keras.backend.get_value(cardinality.cardinality( + dataset)), cardinality.UNKNOWN) + self.assertEqual(keras.backend.get_value(cardinality.cardinality( + predict_dataset)), cardinality.UNKNOWN) + + eval_with_ds = model.evaluate(dataset) + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + + if (distributed_training_utils.is_tpu_strategy(distribution) and + distribution.extended.steps_per_run != 1): + with self.assertRaisesRegexp(ValueError, '`steps_per_epoch` ' + 'should be specified'): + fit_with_ds = model.fit(dataset, epochs=1) + else: + fit_with_ds = model.fit(dataset, + epochs=1).history + fit_with_ds_multiple_epochs = model.fit(dataset, + epochs=2).history + self.assertAllClose( + fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4) + self.assertAllClose( + fit_with_numpy_multiple_epochs, + fit_with_ds_multiple_epochs, atol=1e-4, rtol=1e-4) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): @@ -1123,397 +1200,126 @@ class TestDistributionStrategyWithDatasets(test.TestCase, atol=1e-4, rtol=1e-4) -class Counter(keras.callbacks.Callback): - """Counts the number of times each callback method was run. +class TestRegularizerLoss(test.TestCase, parameterized.TestCase): + class IdentityRegularizer(keras.regularizers.Regularizer): - Attributes: - method_counts: dict. Contains the counts of time each callback method was - run. - """ + def __call__(self, x): + return array_ops.identity(x) - def __init__(self): - self.method_counts = collections.defaultdict(int) - methods_to_count = [ - 'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end', - 'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin', - 'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end', - 'on_test_begin', 'on_test_end', 'on_train_batch_begin', - 'on_train_batch_end', 'on_train_begin', 'on_train_end' - ] - for method_name in methods_to_count: - setattr(self, method_name, - self.wrap_with_counts(method_name, getattr(self, method_name))) + class AddLayer(keras.layers.Layer): - def wrap_with_counts(self, method_name, method): + def build(self, _): + self.v = self.add_weight( + 'v', (), initializer='ones', + regularizer=TestRegularizerLoss.IdentityRegularizer()) - def _call_and_count(*args, **kwargs): - self.method_counts[method_name] += 1 - return method(*args, **kwargs) + def call(self, inputs): + return inputs + self.v - return _call_and_count + @staticmethod + def loss_fn(_, y_pred): + return math_ops.reduce_mean(y_pred) - -class TestDistributionStrategyWithCallbacks(test.TestCase, - parameterized.TestCase): - - @combinations.generate(all_strategy_combinations()) - def test_callbacks_in_fit(self, distribution): + @combinations.generate(all_strategy_combinations_minus_default()) + def test_regularizer_loss(self, distribution): + batch_size = 2 + if not distributed_training_utils.global_batch_size_supported(distribution): + batch_size //= distribution.num_replicas_in_sync + + # Given an input x, which is always 1, and variable v, this model computes + # Loss=x+v+regularizer_loss, where regularizer_loss=v and the variable is + # initialized to 1. Therefore, this model computes Loss=1+2v, and so the + # gradient dLoss/dv = 2. This gradient of 2 is averaged over all examples + # in a batch and then multiplied by the learning rate of 1. As a result, + # the model update for one batch should subtract 2 from v, resulting in v + # being -1. If the regularizer loss is not scaled correctly by number of + # replicas, the variable value will be incorrect when number of replicas + # >1. For e.g. it will be -2 if num replicas = 2. with distribution.scope(): - model = get_model() - model.compile(optimizer='sgd', loss='mse', metrics=['mae']) - - dataset = get_dataset(distribution) - counter = Counter() - - epochs = 2 - steps_per_epoch = 5 - validation_steps = 3 - - model.fit( - dataset, - epochs=epochs, - steps_per_epoch=steps_per_epoch, - verbose=0, - validation_data=dataset, - validation_steps=validation_steps, - callbacks=[counter]) - - if isinstance(distribution, tpu_strategy.TPUStrategy): - # TPU Strategy can have multi step training, from extended.steps_per_run - # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch - steps_per_run = distribution.extended.steps_per_run - num_batch_call_per_epoch = steps_per_epoch // steps_per_run - if steps_per_epoch % steps_per_run: - num_batch_call_per_epoch += 1 - else: - num_batch_call_per_epoch = steps_per_epoch - - self.assertDictEqual( - counter.method_counts, { - 'on_batch_begin': epochs * num_batch_call_per_epoch, - 'on_batch_end': epochs * num_batch_call_per_epoch, - 'on_epoch_begin': epochs, - 'on_epoch_end': epochs, - 'on_test_batch_begin': epochs * validation_steps, - 'on_test_batch_end': epochs * validation_steps, - 'on_test_begin': epochs, - 'on_test_end': epochs, - 'on_train_batch_begin': epochs * num_batch_call_per_epoch, - 'on_train_batch_end': epochs * num_batch_call_per_epoch, - 'on_train_begin': 1, - 'on_train_end': 1 - }) + x = keras.layers.Input(shape=(), batch_size=batch_size) + y = TestRegularizerLoss.AddLayer()(x) + model = keras.models.Model(inputs=x, outputs=y) + opt = gradient_descent_keras.SGD(1.) + model.compile(opt, loss=TestRegularizerLoss.loss_fn) + model.fit( + x=np.array([[1.], [1.]], dtype=np.float32), + y=np.array([[1.], [1.]], dtype=np.float32), + batch_size=batch_size) + v = model.get_weights()[0] + self.assertEqual(-1.0, v) + + +class TestDistributionStrategyWithKerasModels(test.TestCase, + parameterized.TestCase): @combinations.generate(all_strategy_combinations()) - def test_callbacks_in_eval(self, distribution): + def test_distribution_strategy_on_sequential_model(self, distribution): with distribution.scope(): - model = get_model() - model.compile(optimizer='sgd', loss='mse', metrics=['mae']) - - dataset = get_dataset(distribution) - counter = Counter() + model = simple_sequential_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - model.evaluate(dataset, steps=5, callbacks=[counter]) + inputs = np.zeros((20, 10), np.float32) + targets = np.zeros((20, 2), np.float32) - self.assertDictEqual( - counter.method_counts, { - 'on_test_batch_begin': 5, - 'on_test_batch_end': 5, - 'on_test_begin': 1, - 'on_test_end': 1 - }) + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) @combinations.generate(all_strategy_combinations()) - def test_callbacks_in_predict(self, distribution): + def test_distribution_strategy_on_functional_model(self, distribution): with distribution.scope(): model = get_model() - model.compile(optimizer='sgd', loss='mse', metrics=['mae']) - - dataset = get_dataset(distribution) - counter = Counter() - - model.predict(get_predict_dataset(dataset), steps=5, callbacks=[counter]) - - self.assertDictEqual( - counter.method_counts, { - 'on_predict_batch_begin': 5, - 'on_predict_batch_end': 5, - 'on_predict_begin': 1, - 'on_predict_end': 1 - }) - - -class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_shape_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2)) - b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) - device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) - x = values.DistributedValues(device_map, (a, b)) - y = values.DistributedValues(device_map, (a, a)) - # Removed device and input tensor shape details from the error message - # since the order of the device and the corresponding input tensor shape - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor shapes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - with distribution.scope(): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self, - distribution): - with self.cached_session(): - a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) - b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) - device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) - x = values.DistributedValues(device_map, (a, b)) - y = values.DistributedValues(device_map, (a, a)) - # Removed device and input tensor dtype details from the error message - # since the order of the device and the corresponding input tensor dtype - # is not deterministic over different runs. - with self.assertRaisesRegexp(ValueError, - 'Input tensor dtypes do not match for ' - 'distributed tensor inputs ' - 'DistributedValues:.+'): - with distribution.scope(): - distributed_training_utils.validate_distributed_dataset_inputs( - distribution, x, y) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_unsupported_features(self, distribution): - with self.cached_session(): - with distribution.scope(): - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics) - - dataset = get_dataset(distribution) - - # Test with validation split - with self.assertRaisesRegexp( - ValueError, '`validation_split` argument is not ' - 'supported when input `x` is a dataset or a ' - 'dataset iterator.+'): - model.fit(dataset, - epochs=1, steps_per_epoch=2, verbose=0, - validation_split=0.5, validation_steps=2) - - # Test with sample weight. - sample_weight = np.random.random((10,)) - with self.assertRaisesRegexp( - ValueError, '`sample_weight` argument is not supported when input ' - '`x` is a dataset or a dataset iterator.'): - model.fit( - dataset, - epochs=1, - steps_per_epoch=2, - verbose=0, - sample_weight=sample_weight) - - # Test with not specifying the `steps` argument for dataset with infinite - # cardinality. - dataset = dataset.repeat() - with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' - 'repeating dataset, you must specify the ' - '`steps_per_epoch` argument'): - model.fit(dataset, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' - 'repeating dataset, you must specify the ' - '`steps` argument'): - model.evaluate(dataset, verbose=0) - - with self.assertRaisesRegexp(ValueError, 'When passing an infinitely ' - 'repeating dataset, you must specify the ' - '`steps` argument'): - model.predict(dataset, verbose=0) - - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_calling_with_unsupported_predefined_callbacks(self, distribution): - with self.cached_session(): - with distribution.scope(): - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics) - - dataset = get_dataset(distribution) - - def schedule(_): - return 0.001 - with self.assertRaisesRegexp(ValueError, - 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) - - with self.assertRaisesRegexp(ValueError, - 'You must specify a Keras Optimizer V2 when ' - 'using'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.ReduceLROnPlateau()]) - - -class TestDistributionStrategyWithLossMasking(test.TestCase, - parameterized.TestCase): - - # TODO(priyag): Enable all strategies for this test. Currently it does not - # work for TPU due to some invalid datatype. - @combinations.generate(combinations.combine( - distribution=[ - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.core_mirrored_strategy_with_gpu_and_cpu], - mode=['graph', 'eager'])) - def test_masking(self, distribution): - with self.cached_session(): - np.random.seed(1337) - x = np.array([[[1], [1]], [[0], [0]]]) - with distribution.scope(): - model = keras.models.Sequential() - model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01)) - y = np.array([[[1], [1]], [[1], [1]]]) - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) - self.assertEqual(hist.history['loss'][0], 0) - - -class TestDistributionStrategyWithNormalizationLayer( - test.TestCase, parameterized.TestCase): - - @combinations.generate(combinations.times( - all_strategy_combinations(), - combinations.combine(fused=[True, False]))) - def test_batchnorm_correctness(self, distribution, fused): - with self.cached_session(): - with distribution.scope(): - model = keras.models.Sequential() - norm = keras.layers.BatchNormalization( - input_shape=(10,), momentum=0.8, fused=fused) - model.add(norm) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01)) - - # centered on 5.0, variance 10.0 - x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) - x = x.astype('float32') - dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 32, distribution) - - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) - predict_dataset = predict_dataset.repeat(100) - predict_dataset = batch_wrapper(predict_dataset, 32, distribution) - - model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) - out = model.predict(predict_dataset, steps=2) - out -= keras.backend.eval(norm.beta) - out /= keras.backend.eval(norm.gamma) - np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) - np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) - - -class TestDistributionStrategySaveLoadWeights(test.TestCase, - parameterized.TestCase): - - @combinations.generate(all_strategy_combinations_minus_default()) - def test_save_load_h5(self, distribution): - with self.cached_session(): - dataset = get_dataset(distribution) - with distribution.scope(): - model = get_model() - model.compile(gradient_descent_keras.SGD(0.01), 'mse') - model.fit(dataset, epochs=1, steps_per_epoch=1) + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - weights_file = tempfile.mktemp('.h5') - model.save_weights(weights_file) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) - model_2 = get_model() - model_2.compile(gradient_descent_keras.SGD(0.01), 'mse') - model_2.load_weights(weights_file) - model_2.predict(get_predict_dataset(distribution), steps=2) - model_2.fit(dataset, epochs=1, steps_per_epoch=1) + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) + model.predict(inputs, steps=1) + model.evaluate(inputs, targets, steps=1) + # TODO(b/124377929): Remove error assertions once subclassed models + # are supported in DistributedStrategy. @combinations.generate(all_strategy_combinations_minus_default()) - def test_save_load_checkpointable(self, distribution): - # TODO(sourabhbajaj): Test fails with optimizer v2 without h5 - with self.cached_session(): - dataset = get_dataset(distribution) - with distribution.scope(): - model = get_model() - model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') - model.fit(dataset, epochs=1, steps_per_epoch=1) - - weights_file = tempfile.mktemp() - model.save_weights(weights_file) + def test_distribution_strategy_on_subclassed_model(self, distribution): + with distribution.scope(): + model = simple_subclassed_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss) - model_2 = get_model() - model_2.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') - model_2.load_weights(weights_file) - model_2.predict(get_predict_dataset(distribution), steps=2) - model_2.fit(dataset, epochs=1, steps_per_epoch=1) + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 2), dtype=np.float32) + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.fit(inputs, targets, epochs=1, steps_per_epoch=2) -class TestDistributionStrategyValidation(test.TestCase, - parameterized.TestCase): + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.predict(inputs, steps=1) - @combinations.generate(all_strategy_combinations_minus_default()) - def test_layer_outside_scope(self, distribution): - with self.cached_session(): - with self.assertRaisesRegexp( - ValueError, 'was not created in the distribution strategy'): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - with distribution.scope(): - model = keras.Model(x, y) - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics) + with self.assertRaisesRegexp(AttributeError, 'has no attribute'): + model.evaluate(inputs, targets, steps=1) @combinations.generate(all_strategy_combinations_minus_default()) - def test_model_outside_scope(self, distribution): - with self.cached_session(): - with self.assertRaisesRegexp( - ValueError, 'was not created in the distribution strategy'): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) - with distribution.scope(): - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics) + def test_distribution_strategy_one_dimensional(self, distribution): + with distribution.scope(): + inp = keras.layers.Input(shape=(10,)) + out = keras.layers.Dense(3, activation='softmax')(inp) + model = keras.Model(inputs=[inp], outputs=[out]) + model.compile( + optimizer='rmsprop', + loss='sparse_categorical_crossentropy', + metrics=['sparse_categorical_accuracy'], + ) + + x = np.random.random((64, 10)).astype('float32') + y = np.random.randint(3, size=64) + + model.fit(x, y, epochs=1, steps_per_epoch=2) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/keras_utils_test.py b/tensorflow/contrib/distribute/python/keras_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..36eaee77f21a9f6d62a7c3f616d0126b7a4a8902 --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_utils_test.py @@ -0,0 +1,471 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.keras models with callbacks, checkpointing with dist strategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import tempfile +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import keras_test as keras_test_lib +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras +from tensorflow.python.training import gradient_descent + + +class Counter(keras.callbacks.Callback): + """Counts the number of times each callback method was run. + + Attributes: + method_counts: dict. Contains the counts of time each callback method was + run. + """ + + def __init__(self): + self.method_counts = collections.defaultdict(int) + methods_to_count = [ + 'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end', + 'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin', + 'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end', + 'on_test_begin', 'on_test_end', 'on_train_batch_begin', + 'on_train_batch_end', 'on_train_begin', 'on_train_end' + ] + for method_name in methods_to_count: + setattr(self, method_name, + self.wrap_with_counts(method_name, getattr(self, method_name))) + + def wrap_with_counts(self, method_name, method): + + def _call_and_count(*args, **kwargs): + self.method_counts[method_name] += 1 + return method(*args, **kwargs) + + return _call_and_count + + +class TestDistributionStrategyWithCallbacks(test.TestCase, + parameterized.TestCase): + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_fit(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + epochs = 2 + steps_per_epoch = 5 + validation_steps = 3 + + model.fit( + dataset, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=0, + validation_data=dataset, + validation_steps=validation_steps, + callbacks=[counter]) + + if isinstance(distribution, tpu_strategy.TPUStrategy): + # TPU Strategy can have multi step training, from extended.steps_per_run + # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch + steps_per_run = distribution.extended.steps_per_run + num_batch_call_per_epoch = steps_per_epoch // steps_per_run + if steps_per_epoch % steps_per_run: + num_batch_call_per_epoch += 1 + else: + num_batch_call_per_epoch = steps_per_epoch + + self.assertDictEqual( + counter.method_counts, { + 'on_batch_begin': epochs * num_batch_call_per_epoch, + 'on_batch_end': epochs * num_batch_call_per_epoch, + 'on_epoch_begin': epochs, + 'on_epoch_end': epochs, + 'on_test_batch_begin': epochs * validation_steps, + 'on_test_batch_end': epochs * validation_steps, + 'on_test_begin': epochs, + 'on_test_end': epochs, + 'on_train_batch_begin': epochs * num_batch_call_per_epoch, + 'on_train_batch_end': epochs * num_batch_call_per_epoch, + 'on_train_begin': 1, + 'on_train_end': 1 + }) + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_eval(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + model.evaluate(dataset, steps=5, callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_test_batch_begin': 5, + 'on_test_batch_end': 5, + 'on_test_begin': 1, + 'on_test_end': 1 + }) + + @combinations.generate(keras_test_lib.all_strategy_combinations()) + def test_callbacks_in_predict(self, distribution): + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(optimizer='sgd', loss='mse', metrics=['mae']) + + dataset = keras_test_lib.get_dataset(distribution) + counter = Counter() + + model.predict( + keras_test_lib.get_predict_dataset(dataset), + steps=5, + callbacks=[counter]) + + self.assertDictEqual( + counter.method_counts, { + 'on_predict_batch_begin': 5, + 'on_predict_batch_end': 5, + 'on_predict_begin': 1, + 'on_predict_end': 1 + }) + + +class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch( + self, distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp( + ValueError, 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch( + self, distribution): + with self.cached_session(): + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + device_map = values.ReplicaDeviceMap(('/device:CPU:0', '/device:GPU:0')) + x = values.DistributedValues(device_map, (a, b)) + y = values.DistributedValues(device_map, (a, a)) + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp( + ValueError, 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + with distribution.scope(): + distributed_training_utils.validate_distributed_dataset_inputs( + distribution, x, y) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras_test_lib.get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + dataset = keras_test_lib.get_dataset(distribution) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + validation_split=0.5, + validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + ValueError, '`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument for dataset with infinite + # cardinality. + dataset = dataset.repeat() + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp( + ValueError, 'When passing an infinitely ' + 'repeating dataset, you must specify the ' + '`steps` argument'): + model.predict(dataset, verbose=0) + + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): + with self.cached_session(): + with distribution.scope(): + model = keras_test_lib.get_model() + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics) + + dataset = keras_test_lib.get_dataset(distribution) + + def schedule(_): + return 0.001 + + with self.assertRaisesRegexp( + ValueError, 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp( + ValueError, 'You must specify a Keras Optimizer V2 when ' + 'using'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + + +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): + + # TODO(priyag): Enable all strategies for this test. Currently it does not + # work for TPU due to some invalid datatype. + @combinations.generate( + combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu + ], + mode=['graph', 'eager'])) + def test_masking(self, distribution): + with self.cached_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + with distribution.scope(): + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class TestDistributionStrategyWithNormalizationLayer(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + combinations.times(keras_test_lib.all_strategy_combinations(), + combinations.combine(fused=[True, False]))) + def test_batchnorm_correctness(self, distribution, fused): + with self.cached_session(): + with distribution.scope(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization( + input_shape=(10,), momentum=0.8, fused=fused) + model.add(norm) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01)) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = keras_test_lib.batch_wrapper(dataset, 32, distribution) + + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x) + predict_dataset = predict_dataset.repeat(100) + predict_dataset = keras_test_lib.batch_wrapper(predict_dataset, 32, + distribution) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(predict_dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class TestDistributionStrategySaveLoadWeights(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_save_load_h5(self, distribution): + with self.cached_session(): + dataset = keras_test_lib.get_dataset(distribution) + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(gradient_descent_keras.SGD(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp('.h5') + model.save_weights(weights_file) + + model_2 = keras_test_lib.get_model() + model_2.compile(gradient_descent_keras.SGD(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict( + keras_test_lib.get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_save_load_trackable(self, distribution): + # TODO(sourabhbajaj): Test fails with optimizer v2 without h5 + with self.cached_session(): + dataset = keras_test_lib.get_dataset(distribution) + with distribution.scope(): + model = keras_test_lib.get_model() + model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model.fit(dataset, epochs=1, steps_per_epoch=1) + + weights_file = tempfile.mktemp() + model.save_weights(weights_file) + + model_2 = keras_test_lib.get_model() + model_2.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') + model_2.load_weights(weights_file) + model_2.predict( + keras_test_lib.get_predict_dataset(distribution), steps=2) + model_2.fit(dataset, epochs=1, steps_per_epoch=1) + + +class TestDistributionStrategyValidation(test.TestCase, parameterized.TestCase): + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_layer_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + with distribution.scope(): + model = keras.Model(x, y) + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + @combinations.generate( + keras_test_lib.all_strategy_combinations_minus_default()) + def test_model_outside_scope(self, distribution): + with self.cached_session(): + with self.assertRaisesRegexp( + ValueError, 'was not created in the distribution strategy'): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + with distribution.scope(): + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 0b8df787e6b1bde8dce30ea420a3f0e19da23ca4..5ce731816ccefe36c1f876c79589e448f00b86f5 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -116,7 +116,8 @@ class MirroredTwoDeviceDistributionTest( self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, expected_values) - def testMakeInputFnIteratorWithCallable(self, distribution): + # TODO(b/124344198): Re-enable after fixing this flaky test. + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): def fn(): dataset = dataset_ops.Dataset.range(2).interleave( (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2) @@ -606,6 +607,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): aggregation="invalid") def testNonMatchingVariableCreation(self, distribution): + self.skipTest("b/123075960") def model_fn(name): v = variable_scope.variable(1.0, name=name) ds_context.get_replica_context().merge_call(lambda _: _) @@ -1454,7 +1456,7 @@ class MultiWorkerMirroredStrategyTest( self._test_input_fn_iterator( iterator, distribution.extended.worker_devices, expected_values, sess) - def testMakeInputFnIteratorWithCallable(self, distribution): + def DISABLED_testMakeInputFnIteratorWithCallable(self, distribution): self._configure_distribution_strategy(distribution) def fn(): dataset = dataset_ops.Dataset.range(100) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index b05aac431f65b4281d9ed9c2fa95c210d55f4008..7dca13a5b41d1a2db474c44c82f1da88be84df05 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -409,3 +409,25 @@ class IndependentWorkerTestBase(test.TestCase): def join_independent_workers(self, worker_threads): self._coord.join(worker_threads) + + +def get_tf_config_task(): + return json.loads(os.environ['TF_CONFIG'])['task'] + + +def get_tf_config_cluster_spec(): + return json.loads(os.environ['TF_CONFIG'])['cluster'] + + +def get_task_type(): + return get_tf_config_task()['type'] + + +def get_task_index(): + return get_tf_config_task()['index'] + + +def is_chief(): + return ('chief' not in get_tf_config_cluster_spec() + and get_task_type() == 'worker' + and get_task_index() == 0) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 906bffc8525688f63474c3f1fbc5d7f0a024431b..0e56f663d6a1ed7945befd933f2f4a83c5f64342 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -18,36 +18,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import one_device_strategy +from tensorflow.python.eager import context from tensorflow.python.eager import test -from tensorflow.python.framework import test_util +@combinations.generate(combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.one_device_strategy_gpu], + mode=["eager", "graph"])) class OneDeviceStrategyTest( strategy_test_lib.DistributionTestBase, strategy_test_lib.OneDeviceDistributionTestBase): - def _get_distribution_strategy(self): - return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) - def testReplicaId(self): - self._test_replica_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testMakeInputFnIteratorWithDataset(self): - d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMakeInputFnIteratorWithDataset(self, distribution): dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] input_fn = self._input_fn_to_test_input_context( @@ -55,13 +54,11 @@ class OneDeviceStrategyTest( expected_num_replicas_in_sync=1, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - iterator = d.make_input_fn_iterator(input_fn) + iterator = distribution.make_input_fn_iterator(input_fn) self._test_input_fn_iterator( - iterator, d.extended.worker_devices, expected_values) + iterator, distribution.extended.worker_devices, expected_values) - @test_util.run_in_graph_and_eager_modes - def testMakeInputFnIteratorWithCallable(self): - d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + def testMakeInputFnIteratorWithCallable(self, distribution): def fn(): dataset = dataset_ops.Dataset.range(10) it = dataset.make_one_shot_iterator() @@ -72,32 +69,31 @@ class OneDeviceStrategyTest( expected_num_replicas_in_sync=1, expected_num_input_pipelines=1, expected_input_pipeline_id=0) - iterator = d.make_input_fn_iterator(input_fn) + iterator = distribution.make_input_fn_iterator(input_fn) self._test_input_fn_iterator( - iterator, d.extended.worker_devices, expected_values, + iterator, distribution.extended.worker_devices, expected_values, test_reinitialize=False) - @test_util.run_in_graph_and_eager_modes - def testNumpyIterator(self): - self._test_numpy_iterator(self._get_distribution_strategy()) + def testNumpyIterator(self, distribution): + self._test_numpy_iterator(distribution) - def testAllReduceSum(self): - self._test_all_reduce_sum(self._get_distribution_strategy()) + def testAllReduceSum(self, distribution): + self._test_all_reduce_sum(distribution) - def testAllReduceSumGradients(self): - self._test_all_reduce_sum_gradients(self._get_distribution_strategy()) + def testAllReduceSumGradients(self, distribution): + self._test_all_reduce_sum_gradients(distribution) - def testAllReduceSumGradientTape(self): - self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy()) + def testAllReduceSumGradientTape(self, distribution): + self._test_all_reduce_sum_gradient_tape(distribution) - def testAllReduceMean(self): - self._test_all_reduce_mean(self._get_distribution_strategy()) + def testAllReduceMean(self, distribution): + self._test_all_reduce_mean(distribution) - def testAllReduceMeanGradients(self): - self._test_all_reduce_mean_gradients(self._get_distribution_strategy()) + def testAllReduceMeanGradients(self, distribution): + self._test_all_reduce_mean_gradients(distribution) - def testAllReduceMeanGradientTape(self): - self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy()) + def testAllReduceMeanGradientTape(self, distribution): + self._test_all_reduce_mean_gradient_tape(distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index fede253d13804087476fef8b7211a6bfe5789906..3de2041ae35775de6df5bca02c0f1d04a9c2f24e 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -696,6 +696,7 @@ class ParameterServerStrategyTest( def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. # TODO(priyag): Refactor this and other multi worker tests. @combinations.generate( combinations.combine( @@ -704,8 +705,8 @@ class ParameterServerStrategyTest( required_gpus=1, use_core_strategy=[True, False], use_dataset=[True, False])) - def testMakeInputFnIteratorDistributed(self, num_gpus, use_core_strategy, - use_dataset): + def DISABLED_testMakeInputFnIteratorDistributed( + self, num_gpus, use_core_strategy, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -732,6 +733,7 @@ class ParameterServerStrategyTest( test_reinitialize=use_dataset, use_core_strategy=use_core_strategy) + # TODO(b/124344198): Re-enable after fixing this flaky test. @combinations.generate( combinations.combine( mode=['graph'], @@ -739,8 +741,8 @@ class ParameterServerStrategyTest( required_gpus=1, use_core_strategy=[True, False], use_dataset=[True, False])) - def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, - use_dataset): + def DISABLED_testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, + use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 69ce1141d8bea835cb959f503647900fba5f6e25..2d9d221f427422f8bbeba55c5644658af9a9a620 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -26,6 +26,7 @@ import copy from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import device_assignment as device_assignment_lib +from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional_ops from tensorflow.contrib.tpu.python.tpu import topology from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib @@ -41,6 +42,7 @@ from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device @@ -52,6 +54,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -69,11 +72,36 @@ def initialize_tpu_system(cluster_resolver=None): master = cluster_resolver.master() logging.info("Initializing the TPU system.") - session_config = config_pb2.ConfigProto(allow_soft_placement=True) - with ops.Graph().as_default(): - with session_lib.Session(config=session_config, target=master) as sess: - serialized_topology = sess.run(tpu.initialize_system()) + if context.executing_eagerly(): + # This function looks as it is for the following non-intuitive reasons. + # tpu.initialize_system creates a dummy op whose sole purpose is to trigger + # DistributedTPURewritePass. This pass actually adds real ops that + # initialize the TPU system. Thus, we can't simply run tpu.initialize_system + # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. + # The easiest way to trigger a rewrite is to run the function with + # TPUPartitionedCallOp. + @function.defun + def _tpu_init_fn(): + return tpu.initialize_system() + + # We can't call _tpu_init_fn normally (because it contains just a dummy op, + # see above) but need to define it to get it added to eager context + # and get its assigned name. + # pylint: disable=protected-access + graph_func = _tpu_init_fn._get_concrete_function_internal() + func_name = compat.as_str(graph_func._inference_function.name) + # pylint: enable=protected-access + + output = tpu_functional_ops.TPUPartitionedCall( + args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name) + serialized_topology = output[0].numpy() + else: + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + serialized_topology = sess.run(tpu.initialize_system()) + logging.info("Finished initializing TPU system.") return topology.Topology(serialized=serialized_topology) @@ -133,7 +161,7 @@ def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring strategy, device_map, value_list, aggregation, logical_device=logical_device) - if not context.executing_eagerly(): + if not (context.executing_eagerly() or ops.inside_function()): g = ops.get_default_graph() # If "trainable" is True, next_creator() will add the member variables # to the TRAINABLE_VARIABLES collection, so we manually remove @@ -155,8 +183,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def __init__(self, tpu_cluster_resolver=None, steps_per_run=None, - device_assignment=None, - **kwargs): + device_assignment=None): """Initializes the TPUStrategy object. Args: @@ -170,18 +197,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy): device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify the placement of replicas on the TPU cluster. Currently only supports the usecase of using a single core within a TPU cluster. - **kwargs: Additional experimental flags. Will be removed in future. """ - if len(kwargs) > 1: - raise ValueError("TPUStrategy constructor only takes one experimental " - "flag now") - elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs: - raise ValueError("TPUStrategy constructor does not support arguments: " - "{}".format(kwargs)) - super(TPUStrategy, self).__init__(TPUExtended( - self, tpu_cluster_resolver, steps_per_run, device_assignment, - kwargs.get("_disable_training_loop_on_host", False))) + self, tpu_cluster_resolver, steps_per_run, device_assignment)) @property def steps_per_run(self): @@ -193,13 +211,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # This implementation runs a single step. It does not use infeed or outfeed. def experimental_run(self, fn, input_iterator=None): """See base class.""" - if context.executing_eagerly(): - raise NotImplementedError("Eager mode not supported in TPUStrategy.") - - if self.extended._disable_training_loop_on_host: # pylint: disable=protected-access + if context.executing_eagerly() and not ops.inside_function(): raise NotImplementedError( - "`experimental_run` is not compatible with " - "`_disable_training_loop_on_host=True`") + "Eager mode not supported in TPUStrategy outside TF functions.") if input_iterator is None: inputs = [] @@ -207,13 +221,13 @@ class TPUStrategy(distribute_lib.DistributionStrategy): inputs = input_iterator.get_next() result = [None] - def replicated_fn(replica_id, inputs): + def replicated_fn(replica_id, replica_input): """Wraps user function to provide replica ID and `Tensor` inputs.""" with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id): if input_iterator is None: result[0] = fn() else: - result[0] = fn(inputs) + result[0] = fn(replica_input) return result[0] replicate_inputs = [] # By replica. @@ -241,8 +255,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): container_strategy, tpu_cluster_resolver=None, steps_per_run=None, - device_assignment=None, - disable_training_loop_on_host=False): + device_assignment=None): super(TPUExtended, self).__init__(container_strategy) if tpu_cluster_resolver is None: @@ -256,7 +269,6 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) self._device_assignment = device_assignment - self._disable_training_loop_on_host = disable_training_loop_on_host # Device assignment is currently only supported for 1 core case. if self._device_assignment: @@ -284,25 +296,14 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] self._device_map = values.ReplicaDeviceMap(self._tpu_devices) - # If the training loop is on the device, we must use the infeed, with input - # on the host. Otherwise, we preload the data onto the TPUs. - if disable_training_loop_on_host: - input_device_map = values.ReplicaDeviceMap(tuple( - self.get_host_cpu_device(hid) for hid in range(self.num_hosts))) - worker_devices = [ - (self.get_host(hid), [self.get_host_cpu_device(hid)]) - for hid in range(self.num_hosts) - ] - self._input_workers = input_lib.InputWorkers( - input_device_map, worker_devices) - else: - input_worker_devices = collections.OrderedDict() - for tpu_device in self._tpu_devices: - host_device = _get_host_for_device(tpu_device) - input_worker_devices.setdefault(host_device, []) - input_worker_devices[host_device].append(tpu_device) - self._input_workers = input_lib.InputWorkers( - self._device_map, tuple(input_worker_devices.items())) + # Preload the data onto the TPUs. + input_worker_devices = collections.OrderedDict() + for tpu_device in self._tpu_devices: + host_device = _get_host_for_device(tpu_device) + input_worker_devices.setdefault(host_device, []) + input_worker_devices[host_device].append(tpu_device) + self._input_workers = input_lib.InputWorkers( + self._device_map, tuple(input_worker_devices.items())) # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. @@ -402,17 +403,6 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # a mechanism to infer the outputs of `fn`. Pending b/110550782. def _experimental_run_steps_on_iterator( self, fn, multi_worker_iterator, iterations, initial_loop_values=None): - if self._disable_training_loop_on_host: - impl = self._run_steps_on_iterator_with_device_loop - else: - impl = self._run_steps_on_iterator_with_host_loop - - return impl( - fn=fn, multi_worker_iterator=multi_worker_iterator, - iterations=iterations, initial_loop_values=initial_loop_values) - - def _run_steps_on_iterator_with_host_loop( - self, fn, multi_worker_iterator, iterations, initial_loop_values=None): output_shapes = multi_worker_iterator.output_shapes shapes = nest.flatten(output_shapes) if any(not s.is_fully_defined() for s in shapes): @@ -507,79 +497,6 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): _set_last_step_outputs(ctx, last_step_tensor_outputs) return ctx - def _run_steps_on_iterator_with_device_loop( - self, fn, multi_worker_iterator, iterations, initial_loop_values=None): - output_shapes = multi_worker_iterator.output_shapes - shapes = nest.flatten(output_shapes) - if any(not s.is_fully_defined() for s in shapes): - raise ValueError( - "TPU currently requires fully defined shapes. Either use " - "set_shape() on the input tensors or use " - "dataset.batch(..., drop_remainder=True).") - types = nest.flatten(multi_worker_iterator.output_types) - - enqueue_ops = [ - self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, - iterations) - for host_id in range(self.num_hosts)] - - def dequeue_fn(): - dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - return nest.pack_sequence_as(output_shapes, dequeued) - - # Wrap `fn` for repeat. - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - ctx = input_lib.MultiStepContext() - - def run_fn(*args, **kwargs): - """Single step on the TPU device.""" - del args, kwargs - fn_result = fn(ctx, dequeue_fn()) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - if flat_last_step_outputs: - with ops.control_dependencies([fn_result]): - return [array_ops.identity(f) for f in flat_last_step_outputs] - else: - return fn_result - - def iterate_on_tpu(): - return training_loop.repeat(iterations, run_fn, initial_loop_values) - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop and TPU replicate context. This is useful in cases - # where we might need to exit these contexts and get back to the outer - # context to do some things, for e.g. create an op which should be - # evaluated only once at the end of the loop on the host. One such usage - # is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - replicate_inputs = [[]] * self._num_replicas_in_sync - replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) - - del self._outer_control_flow_context - ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) - - # Filter out any ops from the outputs, typically this would be the case - # when there were no tensor outputs. - last_step_tensor_outputs = [x for x in replicate_outputs - if not isinstance(x, ops.Operation)] - - # Outputs are currently of the structure (grouped by device) - # [[output0_device0, output1_device0, output2_device0], - # [output0_device1, output1_device1, output2_device1]] - # Convert this to the following structure instead: (grouped by output) - # [[output0_device0, output0_device1], - # [output1_device0, output1_device1], - # [output2_device0, output2_device1]] - last_step_tensor_outputs = [list(x) for x in - zip(*last_step_tensor_outputs)] - - _set_last_step_outputs(ctx, last_step_tensor_outputs) - return ctx - def _call_for_each_replica(self, fn, args, kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. @@ -619,9 +536,10 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - value_list[0].value()) + if context.executing_eagerly() or ops.inside_function(): + with ops.init_scope(): + kwargs["initial_value"] = array_ops.identity( + value_list[0].value()) else: def initial_value_fn(device=d): with ops.device(device): @@ -655,19 +573,24 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): return cross_device_ops_lib.reduce_non_distributed_value( reduce_op, self._device_map, value, destinations) - # Validate that the destination is same as the host device - # Note we don't do this when in replicate context as the reduction is - # performed on the TPU device itself. devices = cross_device_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - assert device_util.canonicalize(devices[0]) == device_util.canonicalize( - self._host_device) - else: + if len(devices) != 1: raise ValueError("Multiple devices are not supported for TPUStrategy") - output = math_ops.add_n(value) - if reduce_op == reduce_util.ReduceOp.MEAN: - return output * (1. / len(value)) + # Always performs the reduction on the TPU host. + with ops.device(self._host_device): + output = math_ops.add_n(value.values) + if reduce_op == reduce_util.ReduceOp.MEAN: + output *= (1. / len(value.values)) + + # If necessary, copy to requested destination. + dest_canonical = device_util.canonicalize(devices[0]) + host_canonical = device_util.canonicalize(self._host_device) + + if dest_canonical != host_canonical: + with ops.device(devices[0]): + output = array_ops.identity(output) + return output def _update(self, var, fn, args, kwargs, group): diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 51c58b0b2f3dc2ab63e22718825a471b8657f892..9fd251175b8b8e3453e33434b4d86386a078295e 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -656,7 +656,8 @@ class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, replica_local = _make_replica_local("sum", distribution) + v, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM, distribution) # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 8966a9befcd3db4a3f397b319e80f37f84ad236b..d441e4735b64fe1176e77a978d281d46a7b287ab 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -144,7 +144,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", - "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/tracking:base", ], ) diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 78ab155896cfeda4dd259a8529f4b1f77a12cf0b..48925b1bfacc6b59c210b2fb4b53a9a1a851673f 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class IteratorTest(test.TestCase): @@ -238,7 +238,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) dataset = dataset.map(math_ops.square).batch(2) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertAllEqual([1, 4], iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) self.assertAllEqual([9, 16], iterator.get_next().numpy()) @@ -257,7 +257,7 @@ class IteratorTest(test.TestCase): dataset_2 = Dataset.range(10) iterator_3 = datasets.Iterator(dataset_2) - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = trackable_utils.Checkpoint( iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) self.assertEqual(0, iterator_3.get_next().numpy()) @@ -279,7 +279,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(3) iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertEqual(0, iterator.get_next().numpy()) self.assertEqual(1, iterator.get_next().numpy()) save_path = checkpoint.save(checkpoint_prefix) @@ -293,7 +293,7 @@ class IteratorTest(test.TestCase): dataset = Dataset.range(10) for i in range(5): iterator = datasets.Iterator(dataset) - checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint = trackable_utils.Checkpoint(iterator=iterator) checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)) for j in range(2): diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index d18a097063c7d25947af3e2e2959ce574edd553f..3143270ccfe4f670428c80bdc1e09fa452584207 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.training import checkpoint_management -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: enable=g-bad-import-order @@ -421,7 +421,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph = checkpointable_utils.object_metadata( + object_graph = trackable_utils.object_metadata( checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c8d9266672a8b87d32338ea7c4f74fb40d41c767..b32501c2e804838af9d4c77663be131b77bd30b4 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -32,12 +32,12 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(checkpointable.Checkpointable): +class Metric(trackable.Trackable): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -269,7 +269,7 @@ class Metric(checkpointable.Checkpointable): else: collections = [ops.GraphKeys.LOCAL_VARIABLES] collections += [ops.GraphKeys.METRIC_VARIABLES] - # Variables are Checkpointable dependencies of Metrics regardless of the + # Variables are Trackable dependencies of Metrics regardless of the # global/local distinction. Users can avoid saving variables by not adding a # dependency on the Metric. v = self._add_variable_with_custom_getter( @@ -282,7 +282,7 @@ class Metric(checkpointable.Checkpointable): use_resource=True, getter=variable_scope.get_variable, # Raise duplicate variable exceptions from get_variable rather than - # Checkpointable. + # Trackable. overwrite=True) self._vars.append(v) if context.executing_eagerly(): diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 39e5957f5d1760613f2c33607c0bdb163040efb4..c56d1956fde35b562e60496015e666efe9ebc8f6 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -35,7 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils class MetricsTest(test.TestCase): @@ -314,7 +314,7 @@ class MetricsTest(test.TestCase): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") mean = metrics.Mean() - checkpoint = checkpointable_utils.Checkpoint(mean=mean) + checkpoint = trackable_utils.Checkpoint(mean=mean) mean.build() mean._built = True self.evaluate(mean.init_variables()) @@ -327,7 +327,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(200., self.evaluate(mean.value())) restore_mean = metrics.Mean() - restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + restore_checkpoint = trackable_utils.Checkpoint(mean=restore_mean) status = restore_checkpoint.restore(save_path) restore_update = restore_mean(300.) status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 240f213c602395b8589d39c3ecd90b602ffa9848..b3e8daddaf2369e9e33179fde2aab1469e97ea47 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.tracking import util as trackable_utils # pylint: disable=not-callable @@ -65,7 +65,7 @@ class NetworkTest(test.TestCase): def test_checkpointing_not_implemented(self): checkpoint_directory = self.get_temp_dir() - checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + checkpoint = trackable_utils.Checkpoint(net=MyNetwork()) with self.assertRaises(NotImplementedError): checkpoint.save(checkpoint_directory) diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index 7803a6799bb64441fab881bf6ca986d5cf3851a8..258f0a19309235dcd99b31b4de3d35ef8d89b15b 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -30,7 +30,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): @@ -129,8 +129,8 @@ class SharedVariable(resource_variable_ops.ResourceVariable): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") - if isinstance(initial_value, checkpointable.CheckpointInitialValue): - self._maybe_initialize_checkpointable() + if isinstance(initial_value, trackable.CheckpointInitialValue): + self._maybe_initialize_trackable() self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index b82e1bb71bce9a28d7bbbf961cc6d5e25dd18acf..df5b059448f735f7dc1f2963ffbc9c8a8287250a 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -62,7 +62,6 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@Checkpoint @@Checkpointable -@@CheckpointableSaver @@executing_eagerly @@in_eager_mode @@ -138,9 +137,8 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable -from tensorflow.python.training.checkpointable.util import CheckpointableSaver -from tensorflow.python.training.checkpointable.util import Checkpoint +from tensorflow.python.training.tracking.tracking import AutoTrackable as Checkpointable +from tensorflow.python.training.tracking.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 48a6ef4dca0ca7682f7b99b66177679f29ad9ec9..da2479a0b7b029561136903c82cabed9aae622b8 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -203,10 +203,7 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = [ - "nomac", # b/73741358 - "notsan", # b/67512932 - ], + tags = ["notsan"], deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 8fc5f1cfe7800653ef1e43c6d40d1a66e34f2106..0a9199d61f36f10c98b95d79ece7e86765d2db0e 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -14,7 +14,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", - ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -72,60 +71,3 @@ tf_py_test( ], tags = ["no_pip"], ) - -py_library( - name = "sequence_feature_column_v2", - srcs = ["python/feature_column/sequence_feature_column_v2.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - ], -) - -tf_py_test( - name = "sequence_feature_column_v2_test", - srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], - additional_deps = [ - ":sequence_feature_column_v2", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/feature_column:feature_column_py", - "//tensorflow/python/feature_column:feature_column_v2_test", - ], - tags = ["no_pip"], -) - -py_test( - name = "sequence_feature_column_v2_integration_test", - srcs = ["python/feature_column/sequence_feature_column_v2_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":sequence_feature_column_v2", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/feature_column:feature_column_py", - "//tensorflow/python/keras:layers", - ], -) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 3f6dbe0cbdeeae5e2107755f80bcfe5f7fc310e4..8fd2b5f39bc88b76fe5583f8d18389e232ea9f40 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -32,7 +32,6 @@ tf_custom_op_py_library( "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", - "python/ops/critical_section_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", "python/ops/script_ops.py", @@ -172,26 +171,6 @@ py_test( ], ) -cuda_py_test( - name = "critical_section_test", - size = "medium", - srcs = ["python/ops/critical_section_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:context", - ], -) - py_test( name = "ops_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 3784631dcbfbeb215b6c695e4b6f1bbd02fa708c..063717f08aa88f4de9470d8392db2b7c95b3e4bf 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -94,8 +94,6 @@ @@smart_constant_value @@smart_case -@@CriticalSection - @@BoundedTensorSpec @@TensorSpec @@ -129,6 +127,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', + 'is_nested', 'is_sequence', 'is_sequence_or_composite', 'flatten', @@ -139,6 +138,7 @@ _nest_allowed_symbols = [ 'map_structure_with_tuple_paths', 'assert_shallow_structure', 'flatten_up_to', + 'flatten_with_tuple_paths_up_to', 'map_structure_up_to', 'map_structure_with_tuple_paths_up_to', 'get_traverse_shallow_structure', diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index c4976497f5fa95d82e492153b117681f693eaa13..8113bf7c095bd0817e40cfd08bdf1ef7275ba55b 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -22,7 +22,6 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * -from tensorflow.contrib.framework.python.ops.critical_section_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * from tensorflow.contrib.framework.python.ops.script_ops import * diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index b6b75ffa248d66cc4cb49339f193d486f05a6a4a..f13a66717f67a1a627f66af9468c6f2897aaf7a4 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -565,6 +565,20 @@ void LaunchFusedConv2DBiasActivationOp:: fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo( stream->parent()), &algorithms)); + if (activation_mode == ActivationMode::NONE) { + // Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is supported for + // identity activation, other algs seem to quietly do Relu. + // See + // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward + algorithms.erase( + std::remove_if( + algorithms.begin(), algorithms.end(), + [](dnn::AlgorithmDesc alg) { + return alg.algo_id() != + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + }), + algorithms.end()); + } dnn::ProfileResult best_result; dnn::ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index db0868fb2c43464a811b3d6dfcd96480ba2463ee..386e4cf69b7aa118a85fb25bcb809a879c5c1bd8 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -377,7 +377,10 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_pip", + "no_windows", + ], deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 5b9c54e43a16adf457d5ed0e7e73dcd168ab0d67..66af79d1e81bbc450141673dd54d865e5c7932d5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -238,10 +237,10 @@ class GANEstimatorIntegrationTest(test.TestCase): # Evaluate. scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # Predict. predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py index c00ff4399748a77f88d9753df7592bf3859d754e..0fcd1b7924eb02f5d617b45af16852baf2e2bb48 100644 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -235,10 +234,10 @@ class StarGANEstimatorIntegrationTest(test.TestCase): # EVALUTE scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py index 0d9e6489bdd1d89cc49bfedc2eed784999c31d2b..baf2c28df4b63cff525dcf3ff880730768ad000a 100644 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py @@ -23,7 +23,6 @@ import tempfile from absl.testing import parameterized import numpy as np -import six from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples @@ -184,11 +183,11 @@ class TPUGANEstimatorIntegrationTest(test.TestCase, parameterized.TestCase): # Evaluate. num_steps_eval = 2 scores = est.evaluate(eval_input_fn, steps=num_steps_eval) - self.assertIn(ops.GraphKeys.GLOBAL_STEP, six.iterkeys(scores)) - self.assertIn('loss', six.iterkeys(scores)) + self.assertIn(ops.GraphKeys.GLOBAL_STEP, scores) + self.assertIn('loss', scores) self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], scores['loss']) - self.assertIn('mse_custom_metric', six.iterkeys(scores)) + self.assertIn('mse_custom_metric', scores) # Predict. predictions = np.array([x['generated_data'] for x in diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index bd17571a0535a3c8e9dfee24a8da16eb2e72f165..bc7c1057b478fe2656898e68c1a14013b5a71d12 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -365,7 +365,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): unused_image = array_ops.zeros([2, 299, 299, 3]) incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) - with self.test_session(use_gpu=True) as sess: + with self.cached_session(use_gpu=True) as sess: incscore_np = sess.run(incscore, {'concat:0': logits}) self.assertAllClose(_expected_inception_score(logits), incscore_np) @@ -473,7 +473,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): classifier_fn=lambda x: x, max_block_size=600) - with self.test_session() as sess: + with self.cached_session() as sess: actual_kid, actual_std = sess.run(kid_op) expected_kid, expected_std = _expected_kid_and_std(test_pool_real_a, @@ -500,7 +500,7 @@ class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): max_block_size=max_block_size) for block_size in [50, 512, 1000]: - with self.test_session() as sess: + with self.cached_session() as sess: actual_kid, actual_std = sess.run(kid_op, {max_block_size: block_size}) expected_kid, expected_std = _expected_kid_and_std( diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index e3c780ac1a0f0ef15ff993bd3a9bf9730dcb45b8..44ee0f52696dc1cdcd91286a80b2d4b42be93a4d 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -403,7 +403,9 @@ class _PenaltyTest(object): def test_all_correct(self): loss = self._penalty_fn(**self._kwargs) self.assertEqual(self._expected_dtype, loss.dtype) - self.assertEqual(self._expected_op_name, loss.op.name) + # NOTE: Op names will change, it is inappropriate to include them in tests. + # See go/tf-breaking-change. + # self.assertEqual(self._expected_op_name, loss.op.name) with self.cached_session(): variables.global_variables_initializer().run() self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index f36a5d346e0f27fbbc480e876380db51ed559c09..9bff8090d93d3ad7def69726073accfb234ef301 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -757,7 +757,9 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) - +# Begin google-internal +# The four major parts can be found here: http://screen/tMRMBAohDYG. +# End google-internal def stargan_loss( model, generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( @@ -776,8 +778,6 @@ def stargan_loss( add_summaries=True): """StarGAN Loss. - The four major part can be found here: http://screen/tMRMBAohDYG. - Args: model: (StarGAN) Model output of the stargan_model() function call. generator_loss_fn: The loss function on the generator. Takes a diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc index b84710d26eb8a64bf2f86b9f920551a8a8dbb233..755cbdff31cd7ca31579e0d64399d681dc24ad81 100644 --- a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc @@ -100,8 +100,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { // Logic to be executed on the RecvBufAsync callback. auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr, - to_device_ctx, to_tensor, dev_to_dev_stream_index, - done](const Status& s) { + to_device_ctx, to_tensor, done](const Status& s) { if (s.ok()) { remote_memory_manager_->TensorFromTransportOptions( to_tensor, state->call->resp_.transport_options(), to_device, diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 5f8c300155770ed03ad12a9fa5ac74456edaf024..1124dff741309d8fd04954e70c5ebaaf164b940a 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -167,8 +167,11 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // RendezvousMgr already aborted, shouldn't send RPC call any more if (!call->status().ok()) { - done(call->status(), Args(), Args(), Tensor(), false); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. session()->worker_cache->ReleaseWorker(src_worker, rwi); + done(call->status(), Args(), Args(), Tensor(), false); delete call; return; } @@ -181,8 +184,11 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // If StartAbort was called prior to DeregisterCall, then the // current status should be bad. Status s = call->status(); - done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); + // NOTE: `*session()` can potentially be deleted before we return from + // `call->done()(...)`, so we must release the worker before calling the + // callback. session()->worker_cache->ReleaseWorker(src_worker, rwi); + done(s, Args(), call->recv_args(), call->tensor(), call->is_dead()); delete call; Unref(); }); diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index e79ccd8da1f8952758ae322d3a92dec34910a9db..5b37239665d46db38fc249e9004d2200abb3d610 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -22,7 +22,6 @@ from __future__ import print_function from copy import deepcopy from functools import partial from six import iteritems -from six import iterkeys from six import string_types from six import StringIO from tensorflow.contrib.graph_editor import reroute @@ -735,9 +734,8 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", # control dependencies. graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor)) control_ios = util.ControlOutputs(graph) - ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)), - flatten_target_ts, - control_ios=control_ios) + ops = select.get_walks_intersection_ops( + list(replacement_ts), flatten_target_ts, control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index d319aa7986d81cf9ac2d1dc2e15b053a0aa0c31b..92016e6a83975a9b15a39a15125e0eabc111912e 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -19,16 +19,25 @@ tf_cc_binary( "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:candidate_sampling_ops_op_lib", "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:data_flow_ops_op_lib", "//tensorflow/core:framework_internal", "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:io_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:list_ops_op_lib", + "//tensorflow/core:logging_ops_op_lib", + "//tensorflow/core:lookup_ops_op_lib", "//tensorflow/core:manip_ops_op_lib", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:parsing_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:random_ops_op_lib", "//tensorflow/core:remote_fused_graph_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:sparse_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:string_ops_op_lib", "//tensorflow/core:training_ops_op_lib", "//tensorflow/core:user_ops_op_lib", diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index c4476a7bbd5056fa898468a46031bf3d8b1e44cf..b12832d2e2a3cccb4948d9e3bf3d226030121ac2 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -22,7 +22,7 @@ from __future__ import print_function from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy from tensorflow.python.keras.losses import categorical_hinge -from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import cosine_similarity from tensorflow.python.keras.losses import hinge from tensorflow.python.keras.losses import kullback_leibler_divergence from tensorflow.python.keras.losses import logcosh diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 7317fdb52c5b79e787a49d71be49f5261d6b1fff..095b5d798df9ac9038fa1088cdd402dff304e87e 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -23,7 +23,7 @@ from tensorflow.python.keras.metrics import binary_accuracy from tensorflow.python.keras.metrics import binary_crossentropy from tensorflow.python.keras.metrics import categorical_accuracy from tensorflow.python.keras.metrics import categorical_crossentropy -from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import cosine_similarity from tensorflow.python.keras.metrics import hinge from tensorflow.python.keras.metrics import kullback_leibler_divergence from tensorflow.python.keras.metrics import mean_absolute_error diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 403b522ce45ac6ad98a321378626b87aaa7738aa..9d9524e4e4b995d795b7c71b5bd083d11c60d5ce 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2308,7 +2308,7 @@ def layer_norm(inputs, initializer=init_ops.ones_initializer(), collections=gamma_collections, trainable=trainable) - # Calculate the moments on the last axis (layer activations). + # By default, compute the moments across all the dimensions except the one with index 0. norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function. diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 7c2d9bb0767cb979dae9c84b5342d129225677ed..a52d25acf402bdda46771e9146a40cfb71e99d53 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -62,8 +62,8 @@ def _assert_no_variables(test_case): def _assert_metrics(test_case, expected_loss, expected_eval_metrics, model_fn_ops): test_case.assertAlmostEqual(expected_loss, model_fn_ops.loss.eval(), places=4) - for k in six.iterkeys(expected_eval_metrics): - test_case.assertIn(k, six.iterkeys(model_fn_ops.eval_metric_ops)) + for k in expected_eval_metrics: + test_case.assertIn(k, model_fn_ops.eval_metric_ops) variables.initialize_local_variables().run() for key, expected_value in six.iteritems(expected_eval_metrics): value_tensor, update_tensor = model_fn_ops.eval_metric_ops[key] @@ -545,19 +545,19 @@ class MultiLabelHeadTest(test.TestCase): with session.Session(): self.assertListEqual( [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0]) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.731059, 0.5, 0.5], [0.5, 0.5, 0.731059,]], @@ -850,18 +850,18 @@ class BinaryClassificationHeadTest(test.TestCase): with session.Session(): self.assertListEqual( [1, 1], list(model_fn_ops.predictions["classes"].eval())) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.LOGISTIC_REGRESSION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) predicted_classes = predictions_for_serving["classes"].eval().tolist() self.assertListEqual( [b"0", b"1"], predicted_classes[0]) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) def testBinaryClassificationInferMode_withWeightColumn(self): n_classes = 2 @@ -1349,18 +1349,18 @@ class MultiClassHeadTest(test.TestCase): self.assertAllEqual( [0, 2], model_fn_ops.predictions["classes"].eval()) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.576117, 0.2119416, 0.2119416], [0.2119416, 0.2119416, 0.576117]], @@ -1401,18 +1401,18 @@ class MultiClassHeadTest(test.TestCase): self.assertAllEqual( [b"key0", b"key2"], model_fn_ops.predictions["classes"].eval()) - self.assertItemsEqual( - ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) + self.assertItemsEqual(["head_name"], + list(model_fn_ops.output_alternatives)) self.assertEqual( constants.ProblemType.CLASSIFICATION, model_fn_ops.output_alternatives["head_name"][0]) predictions_for_serving = ( model_fn_ops.output_alternatives["head_name"][1]) - self.assertIn("classes", six.iterkeys(predictions_for_serving)) + self.assertIn("classes", predictions_for_serving) self.assertAllEqual( [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]], predictions_for_serving["classes"].eval()) - self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) + self.assertIn("probabilities", predictions_for_serving) self.assertAllClose( [[0.576117, 0.2119416, 0.2119416], [0.2119416, 0.2119416, 0.576117]], diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index a28394964a12013c43d85701b5a0ab5c559afd62..8fda828e994bc2436eaba4475077020436703631 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation -# TODO(rohanj): This should subclass Checkpointable and implement +# TODO(rohanj): This should subclass Trackable and implement # _gather_saveables_for_checkpoint. class ShardedMutableDenseHashTable(object): """A sharded version of MutableDenseHashTable. diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 591eabc66c49f301cf73cd912ebbef70cc9e1e3f..9fe8dafcc8edd6b80625c61a4a0e783e65b44720 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1483,3 +1483,4 @@ class IdTableWithHashBucketsTest(test.TestCase): if __name__ == "__main__": test.main() + diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index 728f75f8ef1eb3b107dbd0ab4ffbecd63787bf3e..f4ebbdeee883ddeef0d47cb561901c16e2195bb2 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -82,10 +82,11 @@ py_library( py_test( name = "metric_loss_ops_test", - size = "large", + size = "medium", srcs = [ "python/metric_learning/metric_loss_ops_test.py", ], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":metric_learning_py", diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 9ea94c74330e3e49414a6a84cd5bc0db3778114a..0a0ba36232075460b561bc54a95fc24973017571 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -40,7 +40,6 @@ tensorflow/core/lib/wav/wav_io.cc tensorflow/core/platform/cpu_info.cc tensorflow/core/platform/default/logging.cc tensorflow/core/platform/default/mutex.cc -tensorflow/core/platform/default/protobuf.cc tensorflow/core/platform/default/tracing.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc @@ -53,6 +52,7 @@ tensorflow/core/platform/posix/error.cc tensorflow/core/platform/posix/load_library.cc tensorflow/core/platform/posix/port.cc tensorflow/core/platform/posix/posix_file_system.cc +tensorflow/core/platform/protobuf.cc tensorflow/core/platform/protobuf_util.cc tensorflow/core/platform/setround.cc tensorflow/core/platform/tensor_coding.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 8330c45cc16ffa536107e25699379bb5d9e8993b..1c1460ce77c99d29785c7e8b8a8e9f770a45b59f 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.cc tensorflow/core/framework/versions.pb.cc tensorflow/core/grappler/costs/op_performance_data.pb.cc tensorflow/core/lib/core/error_codes.pb.cc +tensorflow/core/protobuf/trackable_object_graph.pb.cc tensorflow/core/protobuf/cluster.pb.cc tensorflow/core/protobuf/config.pb.cc tensorflow/core/protobuf/eager_service.pb.cc @@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.cc tensorflow/core/protobuf/named_tensor.pb.cc tensorflow/core/protobuf/queue_runner.pb.cc tensorflow/core/protobuf/rewriter_config.pb.cc +tensorflow/core/protobuf/saved_object_graph.pb.cc tensorflow/core/protobuf/saver.pb.cc +tensorflow/core/protobuf/struct.pb.cc tensorflow/core/protobuf/tensorflow_server.pb.cc tensorflow/core/protobuf/verifier_config.pb.cc tensorflow/core/util/event.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 7257ac8feedfb8ed18c4d691cd85766e70a48ae8..5def632e8a7b65272a1339bdacd92c1fa23012d2 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/variable.pb.h tensorflow/core/framework/versions.pb.h tensorflow/core/grappler/costs/op_performance_data.pb.h tensorflow/core/lib/core/error_codes.pb.h +tensorflow/core/protobuf/trackable_object_graph.pb.h tensorflow/core/protobuf/cluster.pb.h tensorflow/core/protobuf/config.pb.h tensorflow/core/protobuf/debug.pb.h @@ -34,7 +35,9 @@ tensorflow/core/protobuf/meta_graph.pb.h tensorflow/core/protobuf/named_tensor.pb.h tensorflow/core/protobuf/queue_runner.pb.h tensorflow/core/protobuf/rewriter_config.pb.h +tensorflow/core/protobuf/saved_object_graph.pb.h tensorflow/core/protobuf/saver.pb.h +tensorflow/core/protobuf/struct.pb.h tensorflow/core/protobuf/tensor_bundle.pb.h tensorflow/core/protobuf/tensorflow_server.pb.h tensorflow/core/protobuf/verifier_config.pb.h diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index 24d86d313b76343ed9450a33cf185d9c426696bb..deb6a5b94020a02b878bdd68a33b3737a97fcf2b 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -31,6 +31,7 @@ tensorflow/core/framework/versions.proto tensorflow/core/grappler/costs/op_performance_data.proto tensorflow/core/kernels/boosted_trees/boosted_trees.proto tensorflow/core/lib/core/error_codes.proto +tensorflow/core/protobuf/trackable_object_graph.proto tensorflow/core/protobuf/cluster.proto tensorflow/core/protobuf/config.proto tensorflow/core/protobuf/debug.proto @@ -40,7 +41,9 @@ tensorflow/core/protobuf/meta_graph.proto tensorflow/core/protobuf/named_tensor.proto tensorflow/core/protobuf/queue_runner.proto tensorflow/core/protobuf/rewriter_config.proto +tensorflow/core/protobuf/saved_object_graph.proto tensorflow/core/protobuf/saver.proto +tensorflow/core/protobuf/struct.proto tensorflow/core/protobuf/tensor_bundle.proto tensorflow/core/protobuf/tensorflow_server.proto tensorflow/core/protobuf/verifier_config.proto diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 63843b993c16363a80b64622af665aaa64e05830..93701249cc8bf722c8c8558e91e0b700ca1c4a04 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -10,6 +10,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -45,6 +46,28 @@ tf_gen_op_wrapper_py( deps = [":memory_stats_ops_op_lib"], ) +tf_gen_op_wrapper_cc( + name = "memory_stats_ops", + out_ops_file = "memory_stats_ops", +) + +cc_library( + name = "memory_stats_cc", + srcs = ["memory_stats_ops.cc"], + hdrs = ["memory_stats_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":memory_stats_kernels", + ":memory_stats_ops_op_lib", + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "memory_stats_py", srcs = [ diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 974fb537499c5ea4591a0a128f53d2dea67b9e57..7ae1dbeaa2d04d7846e7fada117f3941319cc1c1 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -24,13 +24,15 @@ class MemoryStatsOp : public OpKernel { void Compute(OpKernelContext* context) override { Allocator* allocator = context->device()->GetAllocator(AllocatorAttributes()); - AllocatorStats allocator_stats; - allocator->GetStats(&allocator_stats); + absl::optional allocator_stats = allocator->GetStats(); + if (!allocator_stats) { + *allocator_stats = AllocatorStats(); + } Tensor* output_tensor = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape({}), &output_tensor)); - output_tensor->scalar()() = ExtractAllocatorStats(allocator_stats); + output_tensor->scalar()() = ExtractAllocatorStats(*allocator_stats); } protected: @@ -71,7 +73,7 @@ class BytesLimitOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.bytes_limit; + return allocator_stats.bytes_limit ? *allocator_stats.bytes_limit : -1; } }; @@ -93,7 +95,7 @@ class MaxBytesInUseOp : public MemoryStatsOp { private: int64 ExtractAllocatorStats( const AllocatorStats& allocator_stats) const override { - return allocator_stats.max_bytes_in_use; + return allocator_stats.peak_bytes_in_use; } }; diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h index cae57ce60eb09509af69f8ccab9eacedea361548..9b5d52e1b648e62af93d5420885e4f22796e3ea1 100644 --- a/tensorflow/contrib/mpi_collectives/ring.h +++ b/tensorflow/contrib/mpi_collectives/ring.h @@ -129,7 +129,7 @@ cudaStream_t CudaStreamForMPI(); * has the fully accumulated Segment 1; and so on. The scatter-reduce is * complete. * - * Next, the allgather distributes these fully accumululated chunks across all + * Next, the allgather distributes these fully accumulated chunks across all * nodes. Communication proceeds in the same ring, once again in N-1 steps. At * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). * For example, at the first iteration, the following transfers will occur: diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 12320d9e456ae93cbf95639a0c9e0c7f414f3518..f30643cf3059754daaeee4093938ac47b26f76ea 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -413,8 +413,9 @@ py_test( py_test( name = "shampoo_test", - size = "large", + size = "medium", srcs = ["python/training/shampoo_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":opt_py", diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 0243927ce44aec626973744507e75b20a42253e9..b2ea3daf82ed8daa6e0b9acd8e3cf258b8181615 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -44,14 +44,15 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import tracking -from tensorflow.python.training.checkpointable import util +from tensorflow.python.training.tracking import graph_view +from tensorflow.python.training.tracking import tracking +from tensorflow.python.training.tracking import util -class NonLayerCheckpointable(tracking.AutoCheckpointable): +class NonLayerTrackable(tracking.AutoTrackable): def __init__(self): - super(NonLayerCheckpointable, self).__init__() + super(NonLayerTrackable, self).__init__() self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -64,8 +65,8 @@ class MyModel(training.Model): super(MyModel, self).__init__() self._named_dense = core.Dense(1, use_bias=True) self._second = core.Dense(1, use_bias=False) - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() + # We can still track Trackables which aren't Layers. + self._non_layer = NonLayerTrackable() def call(self, values): ret = self._second(self._named_dense(values)) @@ -100,7 +101,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -116,11 +117,10 @@ class CheckpointingTests(test.TestCase): other_model(input_value), global_step=optimizer_step) self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) - named_variables, serialized_graph, _ = ( - util._serialize_object_graph( - root_checkpointable, saveables_cache=None)) + named_variables, serialized_graph, _ = graph_view.ObjectGraphView( + root_trackable).serialize_object_graph() expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -208,7 +208,7 @@ class CheckpointingTests(test.TestCase): def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -217,24 +217,24 @@ class CheckpointingTests(test.TestCase): else: train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. - root_checkpointable.save_counter # pylint: disable=pointless-statement + root_trackable.save_counter # pylint: disable=pointless-statement self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.])) m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m") self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - save_path = root_checkpointable.save(file_prefix=prefix) + save_path = root_trackable.save(file_prefix=prefix) self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + self.evaluate(state_ops.assign(root_trackable.save_counter, 3)) optimizer_variables = self.evaluate(optimizer.variables()) self.evaluate(state_ops.assign(m_bias_slot, [-2.])) # Immediate restoration - status = root_checkpointable.restore(save_path=save_path).assert_consumed() + status = root_trackable.restore(save_path=save_path).assert_consumed() status.run_restore_ops() self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual(1, self.evaluate(root_trackable.save_counter)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly @@ -440,7 +440,7 @@ class CheckpointingTests(test.TestCase): def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = tracking.AutoCheckpointable() + root = util.Checkpoint() root.var = util.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) @@ -455,21 +455,17 @@ class CheckpointingTests(test.TestCase): util.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = util.CheckpointableSaver(root).save( - os.path.join(checkpoint_directory, "no_slots")) + no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = util.CheckpointableSaver(root).save( - os.path.join(checkpoint_directory, "with_slots")) - new_root = tracking.AutoCheckpointable() + slots_path = root.save(os.path.join(checkpoint_directory, "with_slots")) + new_root = util.Checkpoint() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = util.CheckpointableSaver( - new_root).restore(slots_path) - no_slot_status = util.CheckpointableSaver( - new_root).restore(no_slots_path) + slot_status = new_root.restore(slots_path) + no_slot_status = new_root.restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() new_root.var = util.add_variable( @@ -508,15 +504,14 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.AutoCheckpointable() + obj = util.Checkpoint() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(util.gather_initializers(obj)) - saver = util.CheckpointableSaver(obj) - saver.save(checkpoint_prefix) + obj.save(checkpoint_prefix) before_ops = graph.get_operations() - saver.save(checkpoint_prefix) + obj.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) def testManyRestoresGraph(self): @@ -526,16 +521,15 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = tracking.AutoCheckpointable() + obj = util.Checkpoint() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(util.gather_initializers(obj)) - saver = util.CheckpointableSaver(obj) - save_path = saver.save(checkpoint_prefix) - saver.restore(save_path) + save_path = obj.save(checkpoint_prefix) + obj.restore(save_path) before_ops = graph.get_operations() - saver.restore(save_path) + obj.restore(save_path) self.assertEqual(before_ops, graph.get_operations()) def testMultipleGraphsNonSlotVariables(self): @@ -548,11 +542,11 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = util.Checkpoint( + first_root_trackable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) self.evaluate(util.gather_initializers( - first_root_checkpointable)) + first_root_trackable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) self.evaluate(optimizer.get_slot( @@ -564,23 +558,23 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = util.Checkpoint( + second_root_trackable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) - second_root_checkpointable.restore(None).initialize_or_restore() + second_root_trackable.restore(None).initialize_or_restore() self.evaluate(train_op) self.evaluate(second_variable.assign([4.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([5.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(6.)) - save_path = second_root_checkpointable.save(checkpoint_prefix) + save_path = second_root_trackable.save(checkpoint_prefix) self.evaluate(second_variable.assign([7.])) self.evaluate(optimizer.get_slot( var=second_variable, name="m").assign([8.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(6., self.evaluate(beta_1_power)) - status = second_root_checkpointable.restore(save_path) + status = second_root_trackable.restore(save_path) status.assert_consumed().run_restore_ops() self.assertAllEqual([4.], self.evaluate(second_variable)) self.assertAllEqual([5.], self.evaluate(optimizer.get_slot( @@ -600,7 +594,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): @test_util.run_in_graph_and_eager_modes - def test_checkpointable_save_restore(self): + def test_trackable_save_restore(self): def _templated(): v = variable_scope.get_variable( @@ -647,13 +641,13 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = util.Checkpoint( + root_trackable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) self.evaluate(util.gather_initializers( - root_checkpointable)) + root_trackable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable # with known values to check when loading. @@ -662,24 +656,24 @@ class CheckpointCompatibilityTests(test.TestCase): var=model._named_dense.bias, name="m").assign([2.])) beta_1_power, _ = optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(3.)) - return root_checkpointable + return root_trackable - def _set_sentinels(self, root_checkpointable): - self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.])) + def _set_sentinels(self, root_trackable): + self.evaluate(root_trackable.model._named_dense.bias.assign([101.])) self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m") + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m") .assign([102.])) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.evaluate(beta_1_power.assign(103.)) - def _check_sentinels(self, root_checkpointable): + def _check_sentinels(self, root_trackable): self.assertAllEqual( - [1.], self.evaluate(root_checkpointable.model._named_dense.bias)) + [1.], self.evaluate(root_trackable.model._named_dense.bias)) self.assertAllEqual([2.], self.evaluate( - root_checkpointable.optimizer.get_slot( - var=root_checkpointable.model._named_dense.bias, name="m"))) - beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + root_trackable.optimizer.get_slot( + var=root_trackable.model._named_dense.bias, name="m"))) + beta_1_power, _ = root_trackable.optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta_1_power)) def _write_name_based_checkpoint(self): @@ -704,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = util.CheckpointableSaver(root) + object_saver = util.TrackableSaver(graph_view.ObjectGraphView(root)) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 1323ed014c9e51e273491694fa44a8e36cc723d0..a7f978634ed45012144b2cc49ed069f6fca44f66 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -24,7 +24,6 @@ import abc import six -from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop @@ -39,7 +38,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest @@ -224,7 +223,7 @@ class _OptimizerV2State(object): } self._slots = {} self._non_slot_dict = {} - # Extra state to help Optimizers implement Checkpointable. Holds information + # Extra state to help Optimizers implement Trackable. Holds information # about variables which will be restored as soon as they're created. self._deferred_dependencies = {} # Non-slot variables self._deferred_slot_restorations = {} # Slot variables @@ -367,8 +366,8 @@ class _OptimizerV2State(object): slot variable needs to be restored). Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. optional_op_name: Name to use when scoping the Variable that needs to be @@ -386,7 +385,7 @@ class _OptimizerV2State(object): # (aside from double initialization), and makes variable creator scopes # behave the same way they do when graph building. and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access - initializer = checkpointable.CheckpointInitialValue( + initializer = trackable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self.create_slot( var=variable, @@ -661,7 +660,7 @@ class OptimizerV2(optimizer_v1.Optimizer): name=None, grad_loss=None, stop_gradients=None, - scale_loss_by_num_replicas=None): + scale_loss_by_num_replicas=False): """Add operations to minimize `loss` by updating `var_list`. This method simply combines calls `compute_gradients()` and @@ -685,8 +684,7 @@ class OptimizerV2(optimizer_v1.Optimizer): stop_gradients: Optional. A Tensor or list of tensors not to differentiate through. scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down - by the number of replicas. By default, auto-detects whether this is - needed. + by the number of replicas. DEPRECATED and generally no longer needed. Returns: An Operation that updates the variables in `var_list`. If `global_step` @@ -732,7 +730,7 @@ class OptimizerV2(optimizer_v1.Optimizer): aggregation_method=None, grad_loss=None, stop_gradients=None, - scale_loss_by_num_replicas=None): + scale_loss_by_num_replicas=False): """Compute gradients of `loss` for the variables in `var_list`. This is the first part of `minimize()`. It returns a list @@ -756,8 +754,7 @@ class OptimizerV2(optimizer_v1.Optimizer): stop_gradients: Optional. A Tensor or list of tensors not to differentiate through. scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down - by the number of replicas. By default, auto-detects whether this is - needed. + by the number of replicas. DEPRECATED and generally no longer needed. Returns: A list of (gradient, variable) pairs. Variable is always present, but @@ -781,9 +778,7 @@ class OptimizerV2(optimizer_v1.Optimizer): tape.watch(var_list) loss_value = loss() - # Scale loss for number of replicas (callable-loss case). In this case, - # we have to be careful to call distribute_lib.get_loss_reduction() - # *after* loss() is evaluated, so we know what loss reduction it uses. + # Scale loss for number of replicas (callable-loss case). loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas) if var_list is None: @@ -839,9 +834,6 @@ class OptimizerV2(optimizer_v1.Optimizer): @staticmethod def _scale_loss(loss_value, scale_loss_by_num_replicas): """Scale loss for the number of replicas.""" - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) if scale_loss_by_num_replicas: num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync if num_replicas > 1: @@ -1267,10 +1259,10 @@ class OptimizerV2(optimizer_v1.Optimizer): return self._per_graph_state.get(var._graph_key, None) # -------------- - # Overridden methods from Checkpointable. + # Overridden methods from Trackable. # -------------- - def _track_checkpointable(self, *args, **kwargs): + def _track_trackable(self, *args, **kwargs): """Optimizers may not track dependencies. Raises an error.""" raise NotImplementedError( "Optimizers may not have dependencies. File a feature request if this " @@ -1278,7 +1270,7 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _checkpoint_dependencies(self): - """From Checkpointable. Gather graph-specific non-slot variables to save.""" + """From Trackable. Gather graph-specific non-slot variables to save.""" current_graph_non_slot_variables = [] state = self._get_per_graph_state() if state is not None: @@ -1287,14 +1279,14 @@ class OptimizerV2(optimizer_v1.Optimizer): # Avoid comparing variables key=lambda item: item[0]): current_graph_non_slot_variables.append( - checkpointable.CheckpointableReference( + trackable.TrackableReference( name=name, ref=variable_object)) # Note: ignores super(); Optimizers may not have any dependencies outside of # state objects. return current_graph_non_slot_variables def _lookup_dependency(self, name): - """From Checkpointable. Find a non-slot variable in the current graph.""" + """From Trackable. Find a non-slot variable in the current graph.""" state = self._get_per_graph_state() if state is None: return None @@ -1303,10 +1295,10 @@ class OptimizerV2(optimizer_v1.Optimizer): @property def _deferred_dependencies(self): - """Lets Checkpointable know where non-slot variables are created. + """Lets Trackable know where non-slot variables are created. If necessary, creates a new state object for the current default graph. - Checkpointable will then add entries to that state's deferred dependency + Trackable will then add entries to that state's deferred dependency dictionary. The state object will check that dictionary when creating non-slot variables, restoring their value if an entry is found. @@ -1319,14 +1311,14 @@ class OptimizerV2(optimizer_v1.Optimizer): def _create_or_restore_slot_variable(self, slot_variable_position, slot_name, variable): - """Checkpointable: Restore a slot variable's value, possibly creating it. + """Trackable: Restore a slot variable's value, possibly creating it. Called when a variable which has an associated slot variable is created or restored. Args: - slot_variable_position: A `checkpointable._CheckpointPosition` object - indicating the slot variable `Checkpointable` object to be restored. + slot_variable_position: A `trackable._CheckpointPosition` object + indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. """ diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index e0c6da00d86fe4c5f881bcab7b444182da092b8f..a70f748fad60c6467946225ad5035caaf89c2aaf 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -454,7 +454,7 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, strides=layer_op.get_attr('strides'), padding=layer_op.get_attr('padding'), use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), - data_format=layer_op.get_attr('data_format'), + data_format=layer_op.get_attr('data_format').decode(), name=new_layer_name) elif layer_op.type == 'MatMul': return math_ops.matmul( @@ -867,7 +867,7 @@ class _OpCloner(object): strides=op.get_attr('strides'), padding=op.get_attr('padding'), use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'), - data_format=op.get_attr('data_format'), + data_format=op.get_attr('data_format').decode(), name=new_name).op def _CloneDepthwiseConv2d(self, op, inputs, new_name): diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index d65d80df8073ef70d591c4ae2af99132f1c318ef..24fa740d24502a28cb42c994715d09180ee99899 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -102,26 +102,6 @@ cuda_py_tests( xla_enabled = True, ) -cuda_py_tests( - name = "core_rnn_cell_test", - size = "medium", - srcs = ["python/kernel_tests/core_rnn_cell_test.py"], - additional_deps = [ - ":rnn_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@absl_py//absl/testing:parameterized", - ], -) - cuda_py_tests( name = "rnn_test", size = "medium", @@ -144,32 +124,6 @@ cuda_py_tests( ], ) -cuda_py_tests( - name = "core_rnn_test", - size = "medium", - srcs = ["python/kernel_tests/core_rnn_test.py"], - additional_deps = [ - ":rnn_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:rnn", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - ], - shard_count = 10, -) - tf_py_test( name = "fused_rnn_cell_test", size = "medium", diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py deleted file mode 100644 index a70e806211c644c703f49610414854fe3e16a9b7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ /dev/null @@ -1,1241 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for RNN cells.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib import rnn as contrib_rnn -from tensorflow.contrib.rnn.python.ops import core_rnn_cell -from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util -from tensorflow.python.keras import layers as keras_layers -from tensorflow.python.layers import base as base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import rnn -from tensorflow.python.ops import rnn_cell_impl -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import util as checkpointable_utils - -# pylint: enable=protected-access -Linear = core_rnn_cell._Linear # pylint: disable=invalid-name - - -class RNNCellTest(test.TestCase, parameterized.TestCase): - - def testLinear(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(1.0)): - x = array_ops.zeros([1, 2]) - l = Linear([x], 2, False)([x]) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([l], {x.name: np.array([[1., 2.]])}) - self.assertAllClose(res[0], [[3.0, 3.0]]) - - # Checks prevent you from accidentally creating a shared function. - with self.assertRaises(ValueError): - l1 = Linear([x], 2, False)([x]) - - # But you can create a new one in a new scope and share the variables. - with variable_scope.variable_scope("l1") as new_scope: - l1 = Linear([x], 2, False)([x]) - with variable_scope.variable_scope(new_scope, reuse=True): - Linear([l1], 2, False)([l1]) - self.assertEqual(len(variables_lib.trainable_variables()), 2) - - def testBasicRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = rnn_cell_impl.BasicRNNCell(2) - g, _ = cell(x, m) - self.assertEqual([ - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testBasicRNNCellNotTrainable(self): - with self.cached_session() as sess: - - def not_trainable_getter(getter, *args, **kwargs): - kwargs["trainable"] = False - return getter(*args, **kwargs) - - with variable_scope.variable_scope( - "root", - initializer=init_ops.constant_initializer(0.5), - custom_getter=not_trainable_getter): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = rnn_cell_impl.BasicRNNCell(2) - g, _ = cell(x, m) - self.assertFalse(cell.trainable_variables) - self.assertEqual([ - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.non_trainable_variables]) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testIndRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - cell = contrib_rnn_cell.IndRNNCell(2) - g, _ = cell(x, m) - self.assertEqual([ - "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME - ], [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testGRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = rnn_cell_impl.GRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.175991, 0.175991]]) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test GRUCell with input_size != num_units. - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = rnn_cell_impl.GRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.156736, 0.156736]]) - - def testIndyGRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.185265, 0.17704]]) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test IndyGRUCell with input_size != num_units. - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.155127, 0.157328]]) - - def testSRUCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.SRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.509682, 0.509682]]) - - def testSRUCellKerasRNN(self): - """Tests that SRUCell works with keras RNN layer.""" - cell = contrib_rnn_cell.SRUCell(10) - seq_input = ops.convert_to_tensor( - np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) - rnn_layer = keras_layers.RNN(cell=cell) - rnn_outputs_keras = rnn_layer(seq_input) - with self.cached_session() as sess: - sess.run([variables_lib.global_variables_initializer()]) - self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) - - def testSRUCellBiasType(self): - """Tests that the bias' dtype is properly set.""" - cell = contrib_rnn_cell.SRUCell(10) - cell.build((2, 3, 5)) - self.assertEqual(cell._bias.dtype, dtypes.float32_ref) - - cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) - cell.build((2, 3, 5)) - self.assertEqual(cell._bias.dtype, dtypes.int32_ref) - - cell_input = ops.convert_to_tensor( - np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) - cell_state = ops.convert_to_tensor( - np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) - cell = contrib_rnn_cell.SRUCell(10) - cell(cell_input, [cell_state]) - self.assertEqual(cell._bias.dtype, dtypes.float16_ref) - - def testSRUCellWithDiffSize(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 2]) - g, _ = contrib_rnn_cell.SRUCell(2)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - # Smoke test - self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) - - def testBasicLSTMCell(self): - for dtype in [dtypes.float16, dtypes.float32]: - np_dtype = dtype.as_numpy_dtype - with self.session(graph=ops.Graph()) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2], dtype=dtype) - m = array_ops.zeros([1, 8], dtype=dtype) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=False) - self.assertEqual(cell.dtype, None) - self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) - self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) - cell.get_config() # Should not throw an error - g, out_m = cell(x, m) - # Layer infers the input type. - self.assertEqual(cell.dtype, dtype.name) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual(expected_variable_names, - [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, out_m], { - x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8]) - }) - self.assertEqual(len(res), 2) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # The numbers in results were not calculated, this is just a - # smoke test. - self.assertAllClose(res[0], np.array( - [[0.240, 0.240]], dtype=np_dtype), 1e-2) - expected_mem = np.array( - [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], - dtype=np_dtype) - self.assertAllClose(res[1], expected_mem, 1e-2) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test BasicLSTMCell with input_size != num_units. - x = array_ops.zeros([1, 3], dtype=dtype) - m = array_ops.zeros([1, 4], dtype=dtype) - g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], { - x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - m.name: 0.1 * np.ones([1, 4], dtype=np_dtype) - }) - self.assertEqual(len(res), 2) - - def testBasicLSTMCellDimension0Error(self): - """Tests that dimension 0 in both(x and m) shape must be equal.""" - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - num_units = 2 - state_size = num_units * 2 - batch_size = 3 - input_size = 4 - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size - 1, state_size]) - with self.assertRaises(ValueError): - g, out_m = rnn_cell_impl.BasicLSTMCell( - num_units, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - sess.run( - [g, out_m], { - x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size - 1, state_size]) - }) - - def testBasicLSTMCellStateSizeError(self): - """Tests that state_size must be num_units * 2.""" - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - num_units = 2 - state_size = num_units * 3 # state_size must be num_units * 2 - batch_size = 3 - input_size = 4 - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - with self.assertRaises(ValueError): - g, out_m = rnn_cell_impl.BasicLSTMCell( - num_units, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - sess.run( - [g, out_m], { - x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size, state_size]) - }) - - def testBasicLSTMCellStateTupleType(self): - with self.cached_session(): - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m0 = (array_ops.zeros([1, 2]),) * 2 - m1 = (array_ops.zeros([1, 2]),) * 2 - cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], - state_is_tuple=True) - self.assertTrue(isinstance(cell.state_size, tuple)) - self.assertTrue( - isinstance(cell.state_size[0], rnn_cell_impl.LSTMStateTuple)) - self.assertTrue( - isinstance(cell.state_size[1], rnn_cell_impl.LSTMStateTuple)) - - # Pass in regular tuples - _, (out_m0, out_m1) = cell(x, (m0, m1)) - self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) - - # Pass in LSTMStateTuples - variable_scope.get_variable_scope().reuse_variables() - zero_state = cell.zero_state(1, dtypes.float32) - self.assertTrue(isinstance(zero_state, tuple)) - self.assertTrue(isinstance(zero_state[0], rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(zero_state[1], rnn_cell_impl.LSTMStateTuple)) - _, (out_m0, out_m1) = cell(x, zero_state) - self.assertTrue(isinstance(out_m0, rnn_cell_impl.LSTMStateTuple)) - self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) - - def testBasicLSTMCellWithStateTuple(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m0 = array_ops.zeros([1, 4]) - m1 = array_ops.zeros([1, 4]) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=True) - g, (out_m0, out_m1) = cell(x, (m0, m1)) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m0, out_m1], { - x.name: np.array([[1., 1.]]), - m0.name: 0.1 * np.ones([1, 4]), - m1.name: 0.1 * np.ones([1, 4]) - }) - self.assertEqual(len(res), 3) - # The numbers in results were not calculated, this is just a smoke test. - # Note, however, these values should match the original - # version having state_is_tuple=False. - self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) - expected_mem0 = np.array( - [[0.68967271, 0.68967271, 0.44848421, 0.44848421]]) - expected_mem1 = np.array( - [[0.39897051, 0.39897051, 0.24024698, 0.24024698]]) - self.assertAllClose(res[1], expected_mem0) - self.assertAllClose(res[2], expected_mem1) - - def testIndyLSTMCell(self): - for dtype in [dtypes.float16, dtypes.float32]: - np_dtype = dtype.as_numpy_dtype - with self.session(graph=ops.Graph()) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2], dtype=dtype) - state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - cell = rnn_cell_impl.MultiRNNCell( - [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) - self.assertEqual(cell.dtype, None) - self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) - self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) - cell.get_config() # Should not throw an error - g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) - # Layer infers the input type. - self.assertEqual(cell.dtype, dtype.name) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual(expected_variable_names, - [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_state_0, out_state_1], { - x.name: np.array([[1., 1.]]), - state_0[0].name: 0.1 * np.ones([1, 2]), - state_0[1].name: 0.1 * np.ones([1, 2]), - state_1[0].name: 0.1 * np.ones([1, 2]), - state_1[1].name: 0.1 * np.ones([1, 2]), - }) - self.assertEqual(len(res), 3) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # Only check the range of outputs as this is just a smoke test. - self.assertAllInRange(res[0], -1.0, 1.0) - self.assertAllInRange(res[1], -1.0, 1.0) - self.assertAllInRange(res[2], -1.0, 1.0) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - # Test IndyLSTMCell with input_size != num_units. - x = array_ops.zeros([1, 3], dtype=dtype) - state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 - g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_state], { - x.name: np.array([[1., 1., 1.]], dtype=np_dtype), - state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), - state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), - }) - self.assertEqual(len(res), 2) - - def testLSTMCell(self): - with self.cached_session() as sess: - num_units = 8 - num_proj = 6 - state_size = num_units + num_proj - batch_size = 3 - input_size = 2 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - cell = rnn_cell_impl.LSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - state_is_tuple=False) - output, state = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [output, state], { - x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]), - m.name: 0.1 * np.ones((batch_size, state_size)) - }) - self.assertEqual(len(res), 2) - # The numbers in results were not calculated, this is mostly just a - # smoke test. - self.assertEqual(res[0].shape, (batch_size, num_proj)) - self.assertEqual(res[1].shape, (batch_size, state_size)) - # Different inputs so different outputs and states - for i in range(1, batch_size): - self.assertTrue( - float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) > 1e-6) - self.assertTrue( - float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) - - def testLSTMCellVariables(self): - with self.cached_session(): - num_units = 8 - num_proj = 6 - state_size = num_units + num_proj - batch_size = 3 - input_size = 2 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - m = array_ops.zeros([batch_size, state_size]) - cell = rnn_cell_impl.LSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - state_is_tuple=False) - cell(x, m) # Execute to create variables - variables = variables_lib.global_variables() - self.assertEquals(variables[0].op.name, "root/lstm_cell/kernel") - self.assertEquals(variables[1].op.name, "root/lstm_cell/bias") - self.assertEquals(variables[2].op.name, - "root/lstm_cell/projection/kernel") - - def testLSTMCellLayerNorm(self): - with self.cached_session() as sess: - num_units = 2 - num_proj = 3 - batch_size = 1 - input_size = 4 - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([batch_size, input_size]) - c = array_ops.zeros([batch_size, num_units]) - h = array_ops.zeros([batch_size, num_proj]) - state = rnn_cell_impl.LSTMStateTuple(c, h) - cell = contrib_rnn_cell.LayerNormLSTMCell( - num_units=num_units, - num_proj=num_proj, - forget_bias=1.0, - layer_norm=True, - norm_gain=1.0, - norm_shift=0.0) - g, out_m = cell(x, state) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], { - x.name: np.ones((batch_size, input_size)), - c.name: 0.1 * np.ones((batch_size, num_units)), - h.name: 0.1 * np.ones((batch_size, num_proj)) - }) - self.assertEqual(len(res), 2) - # The numbers in results were not calculated, this is mostly just a - # smoke test. - self.assertEqual(res[0].shape, (batch_size, num_proj)) - self.assertEqual(res[1][0].shape, (batch_size, num_units)) - self.assertEqual(res[1][1].shape, (batch_size, num_proj)) - # Different inputs so different outputs and states - for i in range(1, batch_size): - self.assertTrue( - float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) - self.assertTrue( - float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testWrapperCheckpointing(self): - for wrapper_type in [ - rnn_cell_impl.DropoutWrapper, - rnn_cell_impl.ResidualWrapper, - lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: - cell = rnn_cell_impl.BasicRNNCell(1) - wrapper = wrapper_type(cell) - wrapper(array_ops.ones([1, 1]), - state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) - self.evaluate([v.initializer for v in cell.variables]) - checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(cell._bias.assign([40.])) - save_path = checkpoint.save(prefix) - self.evaluate(cell._bias.assign([0.])) - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - self.assertAllEqual([40.], self.evaluate(cell._bias)) - - def testOutputProjectionWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - cell = contrib_rnn.OutputProjectionWrapper(rnn_cell_impl.GRUCell(3), 2) - g, new_m = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1., 1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 3)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.231907, 0.231907]]) - - def testInputProjectionWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 3]) - cell = contrib_rnn.InputProjectionWrapper( - rnn_cell_impl.GRUCell(3), num_proj=3) - g, new_m = cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 3)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) - - @parameterized.parameters( - [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]) - @test_util.run_in_graph_and_eager_modes - def testResidualWrapper(self, wrapper_type): - x = ops.convert_to_tensor(np.array([[1., 1., 1.]])) - m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) - base_cell = rnn_cell_impl.GRUCell( - 3, kernel_initializer=init_ops.constant_initializer(0.5), - bias_initializer=init_ops.constant_initializer(0.5)) - g, m_new = base_cell(x, m) - wrapper_object = wrapper_type(base_cell) - (name, dep), = wrapper_object._checkpoint_dependencies - wrapper_object.get_config() # Should not throw an error - self.assertIs(dep, base_cell) - self.assertEqual("cell", name) - - g_res, m_new_res = wrapper_object(x, m) - self.evaluate([variables_lib.global_variables_initializer()]) - res = self.evaluate([g, g_res, m_new, m_new_res]) - # Residual connections - self.assertAllClose(res[1], res[0] + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res[2], res[3]) - - @parameterized.parameters( - [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]) - @test_util.run_in_graph_and_eager_modes - def testResidualWrapperWithSlice(self, wrapper_type): - x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]])) - m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]])) - base_cell = rnn_cell_impl.GRUCell( - 3, kernel_initializer=init_ops.constant_initializer(0.5), - bias_initializer=init_ops.constant_initializer(0.5)) - g, m_new = base_cell(x, m) - - def residual_with_slice_fn(inp, out): - inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) - return inp_sliced + out - - g_res, m_new_res = wrapper_type( - base_cell, residual_with_slice_fn)(x, m) - self.evaluate([variables_lib.global_variables_initializer()]) - res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate( - [g, g_res, m_new, m_new_res]) - # Residual connections - self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) - # States are left untouched - self.assertAllClose(res_m_new, res_m_new_res) - - def testDeviceWrapper(self): - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 3]) - m = array_ops.zeros([1, 3]) - wrapped = rnn_cell_impl.GRUCell(3) - cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") - (name, dep), = cell._checkpoint_dependencies - cell.get_config() # Should not throw an error - self.assertIs(dep, wrapped) - self.assertEqual("cell", name) - - outputs, _ = cell(x, m) - self.assertTrue("cpu:14159" in outputs.device.lower()) - - def _retrieve_cpu_gpu_stats(self, run_metadata): - cpu_stats = None - gpu_stats = None - step_stats = run_metadata.step_stats - for ds in step_stats.dev_stats: - if "cpu:0" in ds.device[-5:].lower(): - cpu_stats = ds.node_stats - if "gpu:0" == ds.device[-5:].lower(): - gpu_stats = ds.node_stats - return cpu_stats, gpu_stats - - def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): - if not test.is_gpu_available(): - # Can't perform this test w/o a GPU - return - - gpu_dev = test.gpu_device_name() - with self.session(use_gpu=True) as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 1, 3]) - cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), gpu_dev) - with ops.device("/cpu:0"): - outputs, _ = rnn.dynamic_rnn( - cell=cell, inputs=x, dtype=dtypes.float32) - run_metadata = config_pb2.RunMetadata() - opts = config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE) - - sess.run([variables_lib.global_variables_initializer()]) - _ = sess.run(outputs, options=opts, run_metadata=run_metadata) - - cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) - self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) - self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) - - def testEmbeddingWrapper(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 1], dtype=dtypes.int32) - m = array_ops.zeros([1, 2]) - embedding_cell = contrib_rnn.EmbeddingWrapper( - rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2) - self.assertEqual(embedding_cell.output_size, 2) - g, new_m = embedding_cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g, new_m], { - x.name: np.array([[1]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[1].shape, (1, 2)) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.17139, 0.17139]]) - - def testEmbeddingWrapperWithDynamicRnn(self): - with self.cached_session() as sess: - with variable_scope.variable_scope("root"): - inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) - input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) - embedding_cell = contrib_rnn.EmbeddingWrapper( - rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True), - embedding_classes=1, - embedding_size=2) - outputs, _ = rnn.dynamic_rnn( - cell=embedding_cell, - inputs=inputs, - sequence_length=input_lengths, - dtype=dtypes.float32) - sess.run([variables_lib.global_variables_initializer()]) - # This will fail if output's dtype is inferred from input's. - sess.run(outputs) - - def testMultiRNNCell(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 4]) - multi_rnn_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=False) - _, ml = multi_rnn_cell(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run(ml, { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1, 0.1, 0.1]]) - }) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) - self.assertEqual(len(multi_rnn_cell.weights), 2 * 4) - self.assertTrue( - [x.dtype == dtypes.float32 for x in multi_rnn_cell.weights]) - - def testMultiRNNCellWithStateTuple(self): - with self.cached_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m_bad = array_ops.zeros([1, 4]) - m_good = (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])) - - # Test incorrectness of state - with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): - rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=True)(x, m_bad) - - _, ml = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.GRUCell(2) for _ in range(2)], - state_is_tuple=True)(x, m_good) - - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - ml, { - x.name: np.array([[1., 1.]]), - m_good[0].name: np.array([[0.1, 0.1]]), - m_good[1].name: np.array([[0.1, 0.1]]) - }) - - # The numbers in results were not calculated, this is just a - # smoke test. However, these numbers should match those of - # the test testMultiRNNCell. - self.assertAllClose(res[0], [[0.175991, 0.175991]]) - self.assertAllClose(res[1], [[0.13248, 0.13248]]) - - @parameterized.parameters( - [[rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2], - [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]]) - @test_util.run_in_graph_and_eager_modes - def testWrapperKerasStyle(self, wrapper, wrapper_v2): - """Tests if wrapper cell is instantiated in keras style scope.""" - wrapped_cell_v2 = wrapper_v2(rnn_cell_impl.BasicRNNCell(1)) - self.assertTrue(wrapped_cell_v2._keras_style) - - wrapped_cell = wrapper(rnn_cell_impl.BasicRNNCell(1)) - self.assertFalse(wrapped_cell._keras_style) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) - @test_util.run_in_graph_and_eager_modes - def testWrapperV2VariableNames(self, wrapper): - """Tests that variables names do not depend on wrapper in RNN layer.""" - - def _rnn_input(apply_wrapper, name): - """Creates a RNN layer with/without wrapper and returns built rnn cell.""" - with base_layer.keras_style_scope(): - base_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell") - for _ in range(2)]) - if apply_wrapper: - rnn_cell = wrapper(base_cell) - else: - rnn_cell = base_cell - rnn_layer = keras_layers.RNN(rnn_cell, name=name) - inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) - _ = rnn_layer(inputs) - return base_cell._cells[0] - - rnn_1 = _rnn_input(True, name="rnn_0") - rnn_2 = _rnn_input(False, name="rnn_1") - - for i, cell in enumerate([rnn_1, rnn_2]): - var_prefix = "rnn_{}/cell_0/basic_rnn_cell/".format(i) - self.assertCountEqual([v.name for v in cell.weights], - (var_prefix + "kernel:0", var_prefix + "bias:0")) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) - @test_util.run_in_graph_and_eager_modes - def testWrapperWeights(self, wrapper): - """Tests that wrapper weights contain wrapped cells weights.""" - - with base_layer.keras_style_scope(): - base_cell = rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell") - rnn_cell = wrapper(base_cell) - rnn_layer = keras_layers.RNN(rnn_cell) - inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32) - rnn_layer(inputs) - - expected_weights = ["rnn/" + var for var in ("kernel:0", "bias:0")] - self.assertEqual(len(rnn_cell.weights), 2) - self.assertCountEqual([v.name for v in rnn_cell.weights], expected_weights) - self.assertCountEqual([v.name for v in rnn_cell.trainable_variables], - expected_weights) - self.assertCountEqual([v.name for v in rnn_cell.non_trainable_variables], - []) - self.assertCountEqual([v.name for v in rnn_cell._cell.weights], - expected_weights) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) - @test_util.run_in_graph_and_eager_modes - def testWrapperV2Caller(self, wrapper): - """Tests that wrapper V2 is using the LayerRNNCell's caller.""" - - with base_layer.keras_style_scope(): - base_cell = rnn_cell_impl.MultiRNNCell( - [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) - rnn_cell = wrapper(base_cell) - inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32) - state = ops.convert_to_tensor([[1]], dtype=dtypes.float32) - _ = rnn_cell(inputs, [state, state]) - weights = base_cell._cells[0].weights - self.assertLen(weights, expected_len=2) - self.assertTrue(all(["_wrapper" in v.name for v in weights])) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2]) - @test_util.run_in_graph_and_eager_modes - def testWrapperV2Build(self, wrapper): - cell = rnn_cell_impl.LSTMCell(10) - wrapper = wrapper(cell) - wrapper.build((1,)) - self.assertTrue(cell.built) - - -@test_util.run_all_in_graph_and_eager_modes -class DropoutWrapperTest(test.TestCase, parameterized.TestCase): - - def _testDropoutWrapper(self, - batch_size=None, - time_steps=None, - parallel_iterations=None, - wrapper_type=None, - scope="root", - **kwargs): - if batch_size is None and time_steps is None: - # 2 time steps, batch size 1, depth 3 - batch_size = 1 - time_steps = 2 - x = constant_op.constant( - [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32) - m = rnn_cell_impl.LSTMStateTuple( - *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)] * 2) - else: - x = constant_op.constant( - np.random.randn(time_steps, batch_size, 3).astype(np.float32)) - m = rnn_cell_impl.LSTMStateTuple(*[ - constant_op. - constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)] * 2) - outputs, final_state = rnn.dynamic_rnn( - cell=wrapper_type( - rnn_cell_impl.LSTMCell( - 3, initializer=init_ops.constant_initializer(0.5)), - dtype=x.dtype, **kwargs), - time_major=True, - parallel_iterations=parallel_iterations, - inputs=x, - initial_state=m, - scope=scope) - self.evaluate([variables_lib.global_variables_initializer()]) - res = self.evaluate([outputs, final_state]) - self.assertEqual(res[0].shape, (time_steps, batch_size, 3)) - self.assertEqual(res[1].c.shape, (batch_size, 3)) - self.assertEqual(res[1].h.shape, (batch_size, 3)) - return res - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperProperties(self, wrapper_type): - cell = rnn_cell_impl.BasicRNNCell(10) - wrapper = wrapper_type(cell) - # Github issue 15810 - self.assertEqual(wrapper.wrapped_cell, cell) - self.assertEqual(wrapper.state_size, 10) - self.assertEqual(wrapper.output_size, 10) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperZeroState(self, wrapper_type): - class _Cell(rnn_cell_impl.BasicRNNCell): - - def zero_state(self, batch_size=None, dtype=None): - return "wrapped_cell_zero_state" - wrapper = wrapper_type(_Cell(10)) - self.assertEqual(wrapper.zero_state(10, dtypes.float32), - "wrapped_cell_zero_state") - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepAllConstantInput(self, wrapper_type): - keep = array_ops.ones([]) - res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, - wrapper_type=wrapper_type) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(true_full_output, res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepAll(self, wrapper_type): - keep = variable_scope.get_variable("all", initializer=1.0) - res = self._testDropoutWrapper( - input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep, - wrapper_type=wrapper_type) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(true_full_output, res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperWithSeed(self, wrapper_type): - keep_some = 0.5 - random_seed.set_random_seed(2) - ## Use parallel_iterations = 1 in both calls to - ## _testDropoutWrapper to ensure the (per-time step) dropout is - ## consistent across both calls. Otherwise the seed may not end - ## up being munged consistently across both graphs. - res_standard_1 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - seed=10, - parallel_iterations=1, - wrapper_type=wrapper_type, - scope="root_1") - random_seed.set_random_seed(2) - res_standard_2 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - seed=10, - parallel_iterations=1, - wrapper_type=wrapper_type, - scope="root_2") - self.assertAllClose(res_standard_1[0], res_standard_2[0]) - self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c) - self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepNoOutput(self, wrapper_type): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_none, - state_keep_prob=keep_all, - wrapper_type=wrapper_type) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - self.assertAllClose(np.zeros(res[0].shape), res[0]) - self.assertAllClose(true_full_output[1], res[1].h) - self.assertAllClose(true_full_final_c, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self, wrapper_type): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - # Even though we dropout state, by default DropoutWrapper never - # drops out the memory ("c") term of an LSTMStateTuple. - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_all, - state_keep_prob=keep_none, - wrapper_type=wrapper_type) - true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - self.assertAllClose(true_full_output[0], res[0][0]) - # Second output is modified by zero input state - self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4) - # h state has been set to zero - self.assertAllClose(np.zeros(res[1].h.shape), res[1].h) - # c state of an LSTMStateTuple is NEVER modified. - self.assertAllClose(true_c_state, res[1].c) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperKeepNoInput(self, wrapper_type): - keep_all = variable_scope.get_variable("all", initializer=1.0) - keep_none = variable_scope.get_variable("none", initializer=1e-6) - true_full_output = np.array( - [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]], - dtype=np.float32) - true_full_final_c = np.array( - [[1.949385, 1.949385, 1.949385]], dtype=np.float32) - # All outputs are different because inputs are zeroed out - res = self._testDropoutWrapper( - input_keep_prob=keep_none, - output_keep_prob=keep_all, - state_keep_prob=keep_all, - wrapper_type=wrapper_type) - self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4) - self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4) - self.assertGreater(np.linalg.norm(res[1].c - true_full_final_c), 1e-4) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperRecurrentOutput(self, wrapper_type): - keep_some = 0.8 - keep_all = variable_scope.get_variable("all", initializer=1.0) - res = self._testDropoutWrapper( - input_keep_prob=keep_all, - output_keep_prob=keep_some, - state_keep_prob=keep_all, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7) - # Ensure the same dropout pattern for all time steps - output_mask = np.abs(res[0]) > 1e-6 - for m in output_mask[1:]: - self.assertAllClose(output_mask[0], m) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperRecurrentStateInputAndOutput(self, wrapper_type): - keep_some = 0.9 - res = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7) - - # Smoke test for the state/input masks. - output_mask = np.abs(res[0]) > 1e-6 - for time_step in output_mask: - # Ensure the same dropout output pattern for all time steps - self.assertAllClose(output_mask[0], time_step) - for batch_entry in time_step: - # Assert all batch entries get the same mask - self.assertAllClose(batch_entry, time_step[0]) - - # For state, ensure all batch entries have the same mask - state_c_mask = np.abs(res[1].c) > 1e-6 - state_h_mask = np.abs(res[1].h) > 1e-6 - for batch_entry in state_c_mask: - self.assertAllClose(batch_entry, state_c_mask[0]) - for batch_entry in state_h_mask: - self.assertAllClose(batch_entry, state_h_mask[0]) - - @parameterized.parameters( - [rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2]) - def testDropoutWrapperRecurrentStateInputAndOutputWithSeed( - self, wrapper_type): - keep_some = 0.9 - random_seed.set_random_seed(2347) - np.random.seed(23487) - res0 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7, - seed=-234987, - scope="root_0") - random_seed.set_random_seed(2347) - np.random.seed(23487) - res1 = self._testDropoutWrapper( - input_keep_prob=keep_some, - output_keep_prob=keep_some, - state_keep_prob=keep_some, - variational_recurrent=True, - wrapper_type=wrapper_type, - input_size=3, - batch_size=5, - time_steps=7, - seed=-234987, - scope="root_1") - - output_mask = np.abs(res0[0]) > 1e-6 - for time_step in output_mask: - # Ensure the same dropout output pattern for all time steps - self.assertAllClose(output_mask[0], time_step) - for batch_entry in time_step: - # Assert all batch entries get the same mask - self.assertAllClose(batch_entry, time_step[0]) - - # For state, ensure all batch entries have the same mask - state_c_mask = np.abs(res0[1].c) > 1e-6 - state_h_mask = np.abs(res0[1].h) > 1e-6 - for batch_entry in state_c_mask: - self.assertAllClose(batch_entry, state_c_mask[0]) - for batch_entry in state_h_mask: - self.assertAllClose(batch_entry, state_h_mask[0]) - - # Ensure seeded calculation is identical. - self.assertAllClose(res0[0], res1[0]) - self.assertAllClose(res0[1].c, res1[1].c) - self.assertAllClose(res0[1].h, res1[1].h) - - -def basic_rnn_cell(inputs, state, num_units, scope=None): - if state is None: - if inputs is not None: - batch_size = inputs.get_shape()[0] - dtype = inputs.dtype - else: - batch_size = 0 - dtype = dtypes.float32 - init_output = array_ops.zeros( - array_ops.stack([batch_size, num_units]), dtype=dtype) - init_state = array_ops.zeros( - array_ops.stack([batch_size, num_units]), dtype=dtype) - init_output.set_shape([batch_size, num_units]) - init_state.set_shape([batch_size, num_units]) - return init_output, init_state - else: - with variable_scope.variable_scope(scope, "basic_rnn_cell", - [inputs, state]): - output = math_ops.tanh( - Linear([inputs, state], num_units, True)([inputs, state])) - return output, output - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index d7ee7fb8faacb0876218a983d68f007e1905c11e..dfac2df6a0d4143106ad0f090805597c26659280 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -22,6 +22,7 @@ import itertools import numpy as np +from tensorflow.contrib.rnn.python.ops import core_rnn_cell as legacy_rnn_cell from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -53,6 +54,294 @@ from tensorflow.python.util import nest class RNNCellTest(test.TestCase): + def testIndRNNCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.IndRNNCell(2) + g, _ = cell(x, m) + self.assertEqual([ + "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME + ], [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[0].shape, (1, 2)) + + def testIndyGRUCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.185265, 0.17704]]) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyGRUCell with input_size != num_units. + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.155127, 0.157328]]) + + def testIndyLSTMCell(self): + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + cell = rnn_cell_impl.MultiRNNCell( + [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) + self.assertEqual(cell.dtype, None) + self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) + self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error + g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) + # Layer infers the input type. + self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_state_0, out_state_1], { + x.name: np.array([[1., 1.]]), + state_0[0].name: 0.1 * np.ones([1, 2]), + state_0[1].name: 0.1 * np.ones([1, 2]), + state_1[0].name: 0.1 * np.ones([1, 2]), + state_1[1].name: 0.1 * np.ones([1, 2]), + }) + self.assertEqual(len(res), 3) + global_variables = variables.global_variables() + self.assertEqual(expected_variable_names, + [v.name for v in global_variables]) + # Only check the range of outputs as this is just a smoke test. + self.assertAllInRange(res[0], -1.0, 1.0) + self.assertAllInRange(res[1], -1.0, 1.0) + self.assertAllInRange(res[2], -1.0, 1.0) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyLSTMCell with input_size != num_units. + x = array_ops.zeros([1, 3], dtype=dtype) + state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_state], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + }) + self.assertEqual(len(res), 2) + + def testLSTMCellLayerNorm(self): + with self.cached_session() as sess: + num_units = 2 + num_proj = 3 + batch_size = 1 + input_size = 4 + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([batch_size, input_size]) + c = array_ops.zeros([batch_size, num_units]) + h = array_ops.zeros([batch_size, num_proj]) + state = rnn_cell_impl.LSTMStateTuple(c, h) + cell = contrib_rnn_cell.LayerNormLSTMCell( + num_units=num_units, + num_proj=num_proj, + forget_bias=1.0, + layer_norm=True, + norm_gain=1.0, + norm_shift=0.0) + g, out_m = cell(x, state) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [g, out_m], { + x.name: np.ones((batch_size, input_size)), + c.name: 0.1 * np.ones((batch_size, num_units)), + h.name: 0.1 * np.ones((batch_size, num_proj)) + }) + self.assertEqual(len(res), 2) + # The numbers in results were not calculated, this is mostly just a + # smoke test. + self.assertEqual(res[0].shape, (batch_size, num_proj)) + self.assertEqual(res[1][0].shape, (batch_size, num_units)) + self.assertEqual(res[1][1].shape, (batch_size, num_proj)) + # Different inputs so different outputs and states + for i in range(1, batch_size): + self.assertTrue( + float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) + self.assertTrue( + float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + + def testOutputProjectionWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 3]) + cell = legacy_rnn_cell.OutputProjectionWrapper( + rnn_cell_impl.GRUCell(3), 2) + g, new_m = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.231907, 0.231907]]) + + def testInputProjectionWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 3]) + cell = legacy_rnn_cell.InputProjectionWrapper( + rnn_cell_impl.GRUCell(3), num_proj=3) + g, new_m = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 3)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) + + def testEmbeddingWrapper(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 1], dtype=dtypes.int32) + m = array_ops.zeros([1, 2]) + embedding_cell = legacy_rnn_cell.EmbeddingWrapper( + rnn_cell_impl.GRUCell(2), embedding_classes=3, embedding_size=2) + self.assertEqual(embedding_cell.output_size, 2) + g, new_m = embedding_cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, new_m], { + x.name: np.array([[1]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[1].shape, (1, 2)) + # The numbers in results were not calculated, this is just a smoke test. + self.assertAllClose(res[0], [[0.17139, 0.17139]]) + + def testEmbeddingWrapperWithDynamicRnn(self): + with self.cached_session() as sess: + with variable_scope.variable_scope("root"): + inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) + input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) + embedding_cell = legacy_rnn_cell.EmbeddingWrapper( + rnn_cell_impl.BasicLSTMCell(1, state_is_tuple=True), + embedding_classes=1, + embedding_size=2) + outputs, _ = rnn.dynamic_rnn( + cell=embedding_cell, + inputs=inputs, + sequence_length=input_lengths, + dtype=dtypes.float32) + sess.run([variables.global_variables_initializer()]) + # This will fail if output's dtype is inferred from input's. + sess.run(outputs) + + def testSRUCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.509682, 0.509682]]) + + def testSRUCellKerasRNN(self): + """Tests that SRUCell works with keras RNN layer.""" + cell = contrib_rnn_cell.SRUCell(10) + seq_input = ops.convert_to_tensor( + np.random.rand(2, 3, 5), name="seq_input", dtype=dtypes.float32) + rnn_layer = keras_layers.RNN(cell=cell) + rnn_outputs_keras = rnn_layer(seq_input) + with self.cached_session() as sess: + sess.run([variables.global_variables_initializer()]) + self.assertEqual(sess.run(rnn_outputs_keras).shape, (2, 10)) + + def testSRUCellBiasType(self): + """Tests that the bias' dtype is properly set.""" + cell = contrib_rnn_cell.SRUCell(10) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.float32_ref) + + cell = contrib_rnn_cell.SRUCell(10, dtype=dtypes.int32) + cell.build((2, 3, 5)) + self.assertEqual(cell._bias.dtype, dtypes.int32_ref) + + cell_input = ops.convert_to_tensor( + np.random.rand(2, 5), name="cell_input", dtype=dtypes.float16) + cell_state = ops.convert_to_tensor( + np.random.rand(2, 10), name="cell_state", dtype=dtypes.float16) + cell = contrib_rnn_cell.SRUCell(10) + cell(cell_input, [cell_state]) + self.assertEqual(cell._bias.dtype, dtypes.float16_ref) + + def testSRUCellWithDiffSize(self): + with self.cached_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.55255556, 0.55255556]]) + def testCoupledInputForgetGateLSTMCell(self): with self.cached_session() as sess: num_units = 2 diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py index 0266b72dcb15e4aba01a9a31b4be75c5b84d44da..41b1698321e20f4360d75fa2db79f9bd8a806cea 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn.py +++ b/tensorflow/contrib/rnn/python/ops/rnn.py @@ -131,7 +131,8 @@ def stack_bidirectional_dynamic_rnn(cells_fw, sequence_length=None, parallel_iterations=None, time_major=False, - scope=None): + scope=None, + swap_memory=False): """Creates a dynamic bidirectional recurrent neural network. Stacks several bidirectional rnn layers. The combined forward and backward @@ -171,6 +172,10 @@ def stack_bidirectional_dynamic_rnn(cells_fw, data is batch-major, so by default this function accepts input and emits output in batch-major form. scope: VariableScope for the created subgraph; defaults to None. + swap_memory: Transparently swap the tensors produced in forward inference + but needed for back prop from GPU to CPU. This allows training RNNs + which would typically not fit on a single GPU, with very minimal (or no) + performance penalty. Returns: A tuple (outputs, output_state_fw, output_state_bw) where: @@ -230,6 +235,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, sequence_length=sequence_length, parallel_iterations=parallel_iterations, dtype=dtype, + swap_memory=swap_memory, time_major=time_major) # Concat the outputs to create the new input. prev_layer = array_ops.concat(outputs, 2) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 482e547a16be85804beec88a91fa03b053d09b27..d25afc8b9c4381fb3b0092ef21f46646353e1b8e 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -3153,7 +3153,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): r"""Independently Gated Recurrent Unit cell. Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell, - yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and + yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal matrices, i.e. a Hadamard product with a single vector: @@ -3164,12 +3164,10 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j + [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$ - where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU + where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU node sees only its own state, as opposed to seeing all states in the same layer. - TODO(gonnet): Write a paper describing this and add a reference here. - Args: num_units: int, The number of units in the GRU cell. activation: Nonlinearity to use. Default: `tanh`. @@ -3254,7 +3252,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): self.built = True def call(self, inputs, state): - """Gated recurrent unit (GRU) with nunits cells.""" + """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells.""" gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + ( gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u) @@ -3278,10 +3276,9 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): r"""Basic IndyLSTM recurrent network cell. Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to - BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\) - matrices in - https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate - replaced by diagonal matrices, i.e. a Hadamard product with a single vector: + BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\) + matrices in the regular LSTM equations replaced by diagonal matrices, i.e. a + Hadamard product with a single vector: $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$ $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$ @@ -3289,8 +3286,8 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): $$c_t = f_t \circ c_{t-1} + i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$ - where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM - node sees only its own state \(h\) and \(c\), as opposed to seeing all + where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM + node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all states in the same layer. We add forget_bias (default: 1) to the biases of the forget gate in order to @@ -3298,11 +3295,6 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. - - For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` - that follows. - - TODO(gonnet): Write a paper describing this and add a reference here. """ def __init__(self, diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 0392ed9eee79391c60318faf68d8dfd6eb64a994..a61e9579b84a60d74b73e45a6100a2c772d9cff8 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -22,5 +22,5 @@ from tensorflow.python.keras import saving # TODO(kathywu): Remove all contrib callers, switch to tf.keras. -save_keras_model = saving.export +save_keras_model = saving.export_saved_model load_keras_model = saving.load_from_saved_model diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 2a70b08f5c46e11e7fd83fe134741b9a241153f5..8e2ce82294287dda07d2067c5b9f012f510dbd08 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -269,4 +269,5 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], + shard_count = 4, ) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index d815f81f847ad79ddcc6c6ecf5c050598e185d8d..1a5692f7b5be5e87b78dac9d1ae51f280ca089f8 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -358,7 +358,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00597103), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.6)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -387,7 +387,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -454,7 +454,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333333333)) + shape=(5, 3), dtype=dtype('int32'), mean=1.4)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -696,7 +696,7 @@ class AttentionWrapperTest(test.TestCase): rnn_output=ResultSummary( shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0025896581), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=1.6)) + shape=(5, 3), dtype=dtype('int32'), mean=1.73333333)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -707,12 +707,12 @@ class AttentionWrapperTest(test.TestCase): shape=(5, 6), dtype=dtype('float32'), mean=-0.00069823361), time=3, alignments=ResultSummary( - shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + shape=(5, 8), dtype=dtype('float32'), mean=0.029914695), attention_state=ResultSummary( - shape=(5, 8), dtype=dtype('float32'), mean=0.028698336), + shape=(5, 8), dtype=dtype('float32'), mean=0.029914695), alignment_history=()) expected_final_alignment_history = ResultSummary( - shape=(3, 5, 8), dtype=dtype('float32'), mean=0.04865776002407074) + shape=(3, 5, 8), dtype=dtype('float32'), mean=0.0465225502849) self._testWithAttention( create_attention_mechanism, @@ -921,9 +921,9 @@ class AttentionWrapperTest(test.TestCase): expected_final_output = BasicDecoderOutput( rnn_output=ResultSummary( - shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966), + shape=(5, 3, 20), dtype=dtype('float32'), mean=0.115853324533), sample_id=ResultSummary( - shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667)) + shape=(5, 3), dtype=dtype('int32'), mean=8.6)) expected_final_state = AttentionWrapperState( cell_state=LSTMStateTuple( c=ResultSummary( @@ -931,7 +931,7 @@ class AttentionWrapperTest(test.TestCase): h=ResultSummary( shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)), attention=ResultSummary( - shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207), + shape=(5, 20), dtype=dtype('float32'), mean=0.11462739855), time=3, alignments=(ResultSummary( shape=(5, 8), dtype=dtype('float32'), mean=0.125), diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py index 7ff04e1780c4c44df14d6e87c5afdbf533ca5c90..5ee01f66f165bd2ac22cae10807f24f6b97f0c64 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -17,14 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from absl.testing import parameterized import numpy as np from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper -from tensorflow.python.framework import ops +from tensorflow.contrib.seq2seq.python.ops import basic_decoder +from tensorflow.contrib.seq2seq.python.ops import sampler as sampler_py +from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.keras import initializers +from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.util import nest @test_util.run_all_in_graph_and_eager_modes @@ -37,13 +46,10 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): self.memory_size = 6 self.units = 8 - self.memory = ops.convert_to_tensor( - np.random.random((self.batch, self.timestep, self.memory_size)), - dtype=np.float32) - self.query = ops.convert_to_tensor( - np.random.random((self.batch, self.units)), dtype=np.float32) - self.state = ops.convert_to_tensor( - np.random.random((self.batch, self.timestep)), dtype=np.float32) + self.memory = np.random.randn(self.batch, self.timestep, + self.memory_size).astype(np.float32) + self.query = np.random.randn(self.batch, self.units).astype(np.float32) + self.state = np.random.randn(self.batch, self.timestep).astype(np.float32) @parameterized.named_parameters( ("luong", wrapper.LuongAttentionV2), @@ -52,8 +58,8 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), ) def test_attention_shape_inference(self, attention_cls): - attention = attention_cls(self.units) - attention_score = attention([self.query, self.state, self.memory]) + attention = attention_cls(self.units, self.memory) + attention_score = attention([self.query, self.state]) self.assertLen(attention_score, 2) self.assertEqual(attention_score[0].shape, (self.batch, self.timestep)) self.assertEqual(attention_score[1].shape, (self.batch, self.timestep)) @@ -65,7 +71,7 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), ) def test_get_config(self, attention_cls): - attention = attention_cls(self.units) + attention = attention_cls(self.units, self.memory) config = attention.get_config() attention_from_config = attention_cls.from_config(config) @@ -80,9 +86,8 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), ) def test_layer_output(self, attention_cls): - attention = attention_cls(self.units) - - score = attention([self.query, self.state, self.memory]) + attention = attention_cls(self.units, self.memory) + score = attention([self.query, self.state]) self.evaluate(variables.variables_initializer(attention.variables)) score_val = self.evaluate(score) @@ -90,5 +95,651 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): self.assertEqual(score_val[0].shape, (self.batch, self.timestep)) self.assertEqual(score_val[1].shape, (self.batch, self.timestep)) + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_passing_memory_from_call(self, attention_cls): + attention = attention_cls(self.units, self.memory) + weights_before_query = attention.get_weights() + ref_score = attention([self.query, self.state]) + + self.evaluate(variables.global_variables_initializer()) + ref_score_val = self.evaluate(ref_score) + + all_weights = attention.get_weights() + config = attention.get_config() + # Simulate the twice invocation of calls here. + attention_from_config = attention_cls.from_config(config) + attention_from_config.build(self.memory.shape) + attention_from_config.set_weights(weights_before_query) + attention_from_config(self.memory, setup_memory=True) + attention_from_config.build([self.query.shape, self.state.shape]) + attention_from_config.set_weights(all_weights) + score = attention_from_config([self.query, self.state]) + + score_val = self.evaluate(score) + self.assertAllClose(ref_score_val, score_val) + + @parameterized.named_parameters( + ("luong", wrapper.LuongAttentionV2), + ("luong_monotonic", wrapper.LuongMonotonicAttentionV2), + ("bahdanau", wrapper.BahdanauAttentionV2), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttentionV2), + ) + def test_save_load_layer(self, attention_cls): + vocab = 20 + embedding_dim = 6 + inputs = keras.layers.Input(shape=[self.timestep]) + encoder_input = keras.layers.Embedding( + vocab, embedding_dim, mask_zero=True)( + inputs) + encoder_output = keras.layers.UnifiedLSTM( + self.memory_size, return_sequences=True)( + encoder_input) + + attention = attention_cls(self.units, encoder_output) + query = keras.layers.Input(shape=[self.units]) + state = keras.layers.Input(shape=[self.timestep]) + + score = attention([query, state]) + + x = np.random.randint(vocab, size=(self.batch, self.timestep)) + x_test = np.random.randint(vocab, size=(self.batch, self.timestep)) + y = np.random.randn(self.batch, self.timestep) + model = keras.models.Model([inputs, query, state], score) + model.compile("rmsprop", "mse") + model.fit([x, self.query, self.state], (y, y)) + y_ref = model.predict_on_batch([x_test, self.query, self.state]) + + config = model.get_config() + weights = model.get_weights() + loaded_model = keras.models.Model.from_config( + config, custom_objects={attention_cls.__name__: attention_cls}) + loaded_model.set_weights(weights) + + y = loaded_model.predict_on_batch([x_test, self.query, self.state]) + + self.assertAllClose(y_ref, y) + + # TODO(scottzhu): Add tests for model.compile(run_eagerly=True) + + +class ResultSummary( + collections.namedtuple("ResultSummary", ("shape", "dtype", "mean"))): + pass + + +def get_result_summary(x): + if isinstance(x, np.ndarray): + return ResultSummary(x.shape, x.dtype, x.mean()) + return x + + +@test_util.run_all_in_graph_and_eager_modes +class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): + + def assertAllCloseOrEqual(self, x, y, **kwargs): + if isinstance(x, np.ndarray) or isinstance(x, float): + return super(AttentionWrapperV2Test, self).assertAllClose( + x, y, atol=1e-3, **kwargs) + else: + self.assertAllEqual(x, y, **kwargs) + + def setUp(self): + super(AttentionWrapperV2Test, self).setUp() + self.batch = 64 + self.units = 128 + self.encoder_timestep = 10 + self.encoder_dim = 256 + self.decoder_timestep = 12 + self.encoder_outputs = np.random.randn(self.batch, self.encoder_timestep, + self.encoder_dim) + self.encoder_sequence_length = np.random.randint( + self.encoder_timestep, size=(self.batch,)).astype(np.int32) + self.decoder_inputs = np.random.randn(self.batch, self.decoder_timestep, + self.units) + self.decoder_sequence_length = np.random.randint( + self.decoder_timestep, size=(self.batch,)).astype(np.int32) + + def _testWithAttention(self, + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=3, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_size=6, + attention_layer=None, + create_query_layer=False, + create_memory_layer=True, + create_attention_kwargs=None): + attention_layer_sizes = ([attention_layer_size] + if attention_layer_size is not None else None) + attention_layers = ([attention_layer] + if attention_layer is not None else None) + self._testWithMaybeMultiAttention( + is_multi=False, + create_attention_mechanisms=[create_attention_mechanism], + expected_final_output=expected_final_output, + expected_final_state=expected_final_state, + attention_mechanism_depths=[attention_mechanism_depth], + alignment_history=alignment_history, + expected_final_alignment_history=expected_final_alignment_history, + attention_layer_sizes=attention_layer_sizes, + attention_layers=attention_layers, + create_query_layer=create_query_layer, + create_memory_layer=create_memory_layer, + create_attention_kwargs=create_attention_kwargs) + + def _testWithMaybeMultiAttention(self, + is_multi, + create_attention_mechanisms, + expected_final_output, + expected_final_state, + attention_mechanism_depths, + alignment_history=False, + expected_final_alignment_history=None, + attention_layer_sizes=None, + attention_layers=None, + create_query_layer=False, + create_memory_layer=True, + create_attention_kwargs=None): + # Allow is_multi to be True with a single mechanism to enable test for + # passing in a single mechanism in a list. + assert len(create_attention_mechanisms) == 1 or is_multi + encoder_sequence_length = [3, 2, 3, 1, 1] + decoder_sequence_length = [2, 0, 1, 2, 3] + batch_size = 5 + encoder_max_time = 8 + decoder_max_time = 4 + input_depth = 7 + encoder_output_depth = 10 + cell_depth = 9 + create_attention_kwargs = create_attention_kwargs or {} + + if attention_layer_sizes is not None: + # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. + attention_depth = sum(attention_layer_size or encoder_output_depth + for attention_layer_size in attention_layer_sizes) + elif attention_layers is not None: + # Compute sum of attention_layers output depth. + attention_depth = sum( + attention_layer.compute_output_shape( + [batch_size, cell_depth + encoder_output_depth]).dims[-1].value + for attention_layer in attention_layers) + else: + attention_depth = encoder_output_depth * len(create_attention_mechanisms) + + decoder_inputs = np.random.randn(batch_size, decoder_max_time, + input_depth).astype(np.float32) + encoder_outputs = np.random.randn(batch_size, encoder_max_time, + encoder_output_depth).astype(np.float32) + + attention_mechanisms = [] + for creator, depth in zip(create_attention_mechanisms, + attention_mechanism_depths): + # Create a memory layer with deterministic initializer to avoid randomness + # in the test between graph and eager. + if create_query_layer: + create_attention_kwargs["query_layer"] = keras.layers.Dense( + depth, kernel_initializer="ones", use_bias=False) + if create_memory_layer: + create_attention_kwargs["memory_layer"] = keras.layers.Dense( + depth, kernel_initializer="ones", use_bias=False) + + attention_mechanisms.append( + creator( + units=depth, + memory=encoder_outputs, + memory_sequence_length=encoder_sequence_length, + **create_attention_kwargs)) + + with self.cached_session(use_gpu=True): + attention_layer_size = attention_layer_sizes + attention_layer = attention_layers + if not is_multi: + if attention_layer_size is not None: + attention_layer_size = attention_layer_size[0] + if attention_layer is not None: + attention_layer = attention_layer[0] + cell = rnn_cell.LSTMCell(cell_depth, initializer="ones") + cell = wrapper.AttentionWrapper( + cell, + attention_mechanisms if is_multi else attention_mechanisms[0], + attention_layer_size=attention_layer_size, + alignment_history=alignment_history, + attention_layer=attention_layer) + # Set the attention_layer within AttentionWrapper to have deterministic + # kernel initializer, for testing purpose. + if cell._attention_layers is not None: + for layer in cell._attention_layers: + if getattr(layer, "kernel_initializer") is None: + layer.kernel_initializer = initializers.ones() + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + initial_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size) + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=initial_state, + sequence_length=decoder_sequence_length) + + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple) + + expected_time = ( + expected_final_state.time if context.executing_eagerly() else None) + self.assertEqual((batch_size, expected_time, attention_depth), + tuple(final_outputs.rnn_output.get_shape().as_list())) + self.assertEqual((batch_size, expected_time), + tuple(final_outputs.sample_id.get_shape().as_list())) + + self.assertEqual((batch_size, attention_depth), + tuple(final_state.attention.get_shape().as_list())) + self.assertEqual((batch_size, cell_depth), + tuple(final_state.cell_state.c.get_shape().as_list())) + self.assertEqual((batch_size, cell_depth), + tuple(final_state.cell_state.h.get_shape().as_list())) + + if alignment_history: + if is_multi: + state_alignment_history = [] + for history_array in final_state.alignment_history: + history = history_array.stack() + self.assertEqual((expected_time, batch_size, encoder_max_time), + tuple(history.get_shape().as_list())) + state_alignment_history.append(history) + state_alignment_history = tuple(state_alignment_history) + else: + state_alignment_history = final_state.alignment_history.stack() + self.assertEqual((expected_time, batch_size, encoder_max_time), + tuple(state_alignment_history.get_shape().as_list())) + nest.assert_same_structure(cell.state_size, + cell.zero_state(batch_size, dtypes.float32)) + # Remove the history from final_state for purposes of the + # remainder of the tests. + final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access + else: + state_alignment_history = () + + self.evaluate(variables.global_variables_initializer()) + eval_result = self.evaluate({ + "final_outputs": final_outputs, + "final_state": final_state, + "state_alignment_history": state_alignment_history, + }) + + final_output_info = nest.map_structure(get_result_summary, + eval_result["final_outputs"]) + final_state_info = nest.map_structure(get_result_summary, + eval_result["final_state"]) + print("final_output_info: ", final_output_info) + print("final_state_info: ", final_state_info) + + nest.map_structure(self.assertAllCloseOrEqual, expected_final_output, + final_output_info) + nest.map_structure(self.assertAllCloseOrEqual, expected_final_state, + final_state_info) + if alignment_history: # by default, the wrapper emits attention as output + final_alignment_history_info = nest.map_structure( + get_result_summary, eval_result["state_alignment_history"]) + print("final_alignment_history_info: ", final_alignment_history_info) + nest.map_structure( + self.assertAllCloseOrEqual, + # outputs are batch major but the stacked TensorArray is time major + expected_final_alignment_history, + final_alignment_history_info) + + @parameterized.parameters([np.float16, np.float32, np.float64]) + def _testBahdanauNormalizedDType(self, dtype): + encoder_outputs = self.encoder_outputs.astype(dtype) + decoder_inputs = self.decoder_inputs.astype(dtype) + attention_mechanism = wrapper.BahdanauAttentionV2( + units=self.units, + memory=encoder_outputs, + memory_sequence_length=self.encoder_sequence_length, + normalize=True, + dtype=dtype) + cell = rnn_cell.LSTMCell(self.units) + cell = wrapper.AttentionWrapper(cell, attention_mechanism) + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch), + sequence_length=self.decoder_sequence_length) + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple) + + @parameterized.parameters([np.float16, np.float32, np.float64]) + def testLuongScaledDType(self, dtype): + # Test case for GitHub issue 18099 + encoder_outputs = self.encoder_outputs.astype(dtype) + decoder_inputs = self.decoder_inputs.astype(dtype) + attention_mechanism = wrapper.LuongAttentionV2( + units=self.units, + memory=encoder_outputs, + memory_sequence_length=self.encoder_sequence_length, + scale=True, + dtype=dtype, + ) + cell = rnn_cell.LSTMCell(self.units) + cell = wrapper.AttentionWrapper(cell, attention_mechanism) + + sampler = sampler_py.TrainingSampler() + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + + final_outputs, final_state, _ = my_decoder( + decoder_inputs, + initial_state=cell.zero_state(dtype=dtype, batch_size=self.batch), + sequence_length=self.decoder_sequence_length) + self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) + self.assertEqual(final_outputs.rnn_output.dtype, dtype) + self.assertIsInstance(final_state, wrapper.AttentionWrapperState) + self.assertIsInstance(final_state.cell_state, rnn_cell.LSTMStateTuple) + + def testBahdanauNotNormalized(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype(np.float32), mean=4.8290324), + sample_id=ResultSummary(shape=(5, 3), dtype=np.dtype(np.int32), mean=0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype(np.float32), mean=1.6432636), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype(np.float32), mean=0.75866824)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype(np.float32), mean=6.7445569), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype(np.float32), mean=0.125), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype(np.float32), mean=0.125) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + create_query_layer=True, + expected_final_alignment_history=expected_final_alignment_history, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauNormalized(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones", "normalize": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.9548259), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.4652209), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.70997983)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=6.3075728), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testLuongNotNormalized(self): + create_attention_mechanism = wrapper.LuongAttentionV2 + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=4.084631), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9) + + def testLuongScaled(self): + create_attention_mechanism = wrapper.LuongAttentionV2 + create_attention_kwargs = {"scale": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=2.6605489), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.88403547), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.37819088)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=4.0846314), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + create_attention_kwargs=create_attention_kwargs) + + def testNotUseAttentionLayer(self): + create_attention_mechanism = wrapper.BahdanauAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 10), dtype=np.dtype("float32"), mean=0.072406612), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=3.86666666)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.032002), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.61177742)), + attention=ResultSummary( + shape=(5, 10), dtype=np.dtype("float32"), mean=0.011346335), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.125), + alignment_history=()) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_layer_size=None, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauMonotonicNotNormalized(self): + create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones"} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=5.9850435), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.6752492), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.76052248)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=8.361186), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.10989678), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.117412611) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testBahdanauMonotonicNormalized(self): + create_attention_mechanism = wrapper.BahdanauMonotonicAttentionV2 + create_attention_kwargs = {"kernel_initializer": "ones", + "normalize": True} + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=4.5706983), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.6005473), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.77863038)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=7.3326721), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.12258384), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.12258384) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + create_query_layer=True, + create_attention_kwargs=create_attention_kwargs) + + def testLuongMonotonicNotNormalized(self): + create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history) + + def testLuongMonotonicScaled(self): + create_attention_mechanism = wrapper.LuongMonotonicAttentionV2 + create_attention_kwargs = {"scale": True} + + expected_final_output = basic_decoder.BasicDecoderOutput( + rnn_output=ResultSummary( + shape=(5, 3, 6), dtype=np.dtype("float32"), mean=3.159497), + sample_id=ResultSummary( + shape=(5, 3), dtype=np.dtype("int32"), mean=0.0)) + expected_final_state = wrapper.AttentionWrapperState( + cell_state=rnn_cell.LSTMStateTuple( + c=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=1.072384), + h=ResultSummary( + shape=(5, 9), dtype=np.dtype("float32"), mean=0.50331038)), + attention=ResultSummary( + shape=(5, 6), dtype=np.dtype("float32"), mean=5.3079605), + time=3, + alignments=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + attention_state=ResultSummary( + shape=(5, 8), dtype=np.dtype("float32"), mean=0.11467695), + alignment_history=()) + expected_final_alignment_history = ResultSummary( + shape=(3, 5, 8), dtype=np.dtype("float32"), mean=0.11899644) + + self._testWithAttention( + create_attention_mechanism, + expected_final_output, + expected_final_state, + attention_mechanism_depth=9, + alignment_history=True, + expected_final_alignment_history=expected_final_alignment_history, + create_attention_kwargs=create_attention_kwargs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py index 15fb688fc4dd4909e5bab36def7ac58e9d7be4ea..2341ebb77ab6ecad1e979bc8bed0080128a804da 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_v2_test.py @@ -127,7 +127,7 @@ class BasicDecoderTest(keras_parameterized.TestCase): np.argmax(eval_result["step_outputs"].rnn_output, -1), eval_result["step_outputs"].sample_id) - def testStepWithGreedyEmbeddingHelper(self): + def DISABLED_testStepWithGreedyEmbeddingHelper(self): batch_size = 5 vocabulary_size = 7 cell_depth = vocabulary_size # cell's logits must match vocabulary size diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 5e28e651c666b1c448f778fc9c02d637ce817bae..56f2a0acc9f2e6f951c5df26a53a31645697da4f 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -25,10 +25,13 @@ from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -530,11 +533,10 @@ class BeamSearchDecoderTest(test.TestCase): return (shape[1], shape[0]) + shape[2:] return shape - self.assertTrue( - isinstance(final_outputs, - beam_search_decoder.FinalBeamSearchDecoderOutput)) - self.assertTrue( - isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)) + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) beam_search_decoder_output = final_outputs.beam_search_decoder_output self.assertEqual( @@ -574,5 +576,119 @@ class BeamSearchDecoderTest(test.TestCase): with_alignment_history=True) +@test_util.run_all_in_graph_and_eager_modes +class BeamSearchDecoderV2Test(test.TestCase): + + def _testDynamicDecodeRNN(self, time_major, has_attention, + with_alignment_history=False): + encoder_sequence_length = np.array([3, 2, 3, 1, 1]) + decoder_sequence_length = np.array([2, 0, 1, 2, 3]) + batch_size = 5 + decoder_max_time = 4 + input_depth = 7 + cell_depth = 9 + attention_depth = 6 + vocab_size = 20 + end_token = vocab_size - 1 + start_token = 0 + embedding_dim = 50 + max_out = max(decoder_sequence_length) + output_layer = layers.Dense(vocab_size, use_bias=True, activation=None) + beam_width = 3 + + with self.cached_session(): + batch_size_tensor = constant_op.constant(batch_size) + embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) + cell = rnn_cell.LSTMCell(cell_depth) + initial_state = cell.zero_state(batch_size, dtypes.float32) + coverage_penalty_weight = 0.0 + if has_attention: + coverage_penalty_weight = 0.2 + inputs = array_ops.placeholder_with_default( + np.random.randn(batch_size, decoder_max_time, input_depth).astype( + np.float32), + shape=(None, None, input_depth)) + tiled_inputs = beam_search_decoder.tile_batch( + inputs, multiplier=beam_width) + tiled_sequence_length = beam_search_decoder.tile_batch( + encoder_sequence_length, multiplier=beam_width) + attention_mechanism = attention_wrapper.BahdanauAttention( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + initial_state = beam_search_decoder.tile_batch( + initial_state, multiplier=beam_width) + cell = attention_wrapper.AttentionWrapper( + cell=cell, + attention_mechanism=attention_mechanism, + attention_layer_size=attention_depth, + alignment_history=with_alignment_history) + cell_state = cell.zero_state( + dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) + if has_attention: + cell_state = cell_state.clone(cell_state=initial_state) + bsd = beam_search_decoder.BeamSearchDecoderV2( + cell=cell, + beam_width=beam_width, + output_layer=output_layer, + length_penalty_weight=0.0, + coverage_penalty_weight=coverage_penalty_weight, + output_time_major=time_major, + maximum_iterations=max_out) + + final_outputs, final_state, final_sequence_lengths = bsd( + embedding, + start_tokens=array_ops.fill([batch_size_tensor], start_token), + end_token=end_token, + initial_state=cell_state) + + def _t(shape): + if time_major: + return (shape[1], shape[0]) + shape[2:] + return shape + + self.assertIsInstance( + final_outputs, beam_search_decoder.FinalBeamSearchDecoderOutput) + self.assertIsInstance( + final_state, beam_search_decoder.BeamSearchDecoderState) + + beam_search_decoder_output = final_outputs.beam_search_decoder_output + expected_seq_length = 3 if context.executing_eagerly() else None + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(beam_search_decoder_output.scores.get_shape().as_list())) + self.assertEqual( + _t((batch_size, expected_seq_length, beam_width)), + tuple(final_outputs.predicted_ids.get_shape().as_list())) + + self.evaluate(variables.global_variables_initializer()) + eval_results = self.evaluate({ + 'final_outputs': final_outputs, + 'final_sequence_lengths': final_sequence_lengths + }) + + max_sequence_length = np.max(eval_results['final_sequence_lengths']) + + # A smoke test + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), + eval_results['final_outputs'].beam_search_decoder_output.scores.shape) + self.assertEqual( + _t((batch_size, max_sequence_length, beam_width)), eval_results[ + 'final_outputs'].beam_search_decoder_output.predicted_ids.shape) + + def testDynamicDecodeRNNBatchMajorNoAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=False) + + def testDynamicDecodeRNNBatchMajorYesAttention(self): + self._testDynamicDecodeRNN(time_major=False, has_attention=True) + + def testDynamicDecodeRNNBatchMajorYesAttentionWithAlignmentHistory(self): + self._testDynamicDecodeRNN( + time_major=False, + has_attention=True, + with_alignment_history=True) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index ae3e7f1b5d8c9f06b5defbaee9cad3810e58abd4..79c2ac2f500307ba23b6d97a7a30c6d04cea5176 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -25,10 +25,13 @@ import math import numpy as np from tensorflow.contrib.framework.python.framework import tensor_util +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import initializers from tensorflow.python.keras import layers +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.layers import base as layers_base from tensorflow.python.layers import core as layers_core from tensorflow.python.ops import array_ops @@ -225,18 +228,39 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): 1. Storing the query and memory layers. 2. Preprocessing and storing the memory. - Note that this layer only support Keras functional API since it takes multiple - input tensors, which is not available in sequential model. + Note that this layer takes memory as its init parameter, which is an + anti-pattern of Keras API, we have to keep the memory as init parameter for + performance and dependency reason. Under the hood, during `__init__()`, it + will invoke `base_layer.__call__(memory, setup_memory=True)`. This will let + keras to keep track of the memory tensor as the input of this layer. Once + the `__init__()` is done, then user can query the attention by + `score = att_obj([query, state])`, and use it as a normal keras layer. + + Special attention is needed when adding using this class as the base layer for + new attention: + 1. Build() could be invoked at least twice. So please make sure weights are + not duplicated. + 2. Layer.get_weights() might return different set of weights if the instance + has `query_layer`. The query_layer weights is not initialized until the + memory is configured. + + Also note that this layer does not work with Keras model when + `model.compile(run_eagerly=True)` due to the fact that this layer is stateful. + The support for that will be added in a future version. """ def __init__(self, + memory, probability_fn, query_layer=None, memory_layer=None, + memory_sequence_length=None, **kwargs): """Construct base AttentionMechanism class. Args: + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. probability_fn: A `callable`. Converts the score and previous alignments to probabilities. Its signature should be: `probabilities = probability_fn(score, state)`. @@ -247,6 +271,9 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): depth must match the depth of `query_layer`. If `memory_layer` is not provided, the shape of `memory` must match that of `query_layer`. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. **kwargs: Dictionary that contains other common arguments for layer creation. """ @@ -273,20 +300,127 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): self.batch_size = None self._memory_initialized = False self._check_inner_dims_defined = True + self.supports_masking = True + self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) + + if memory is not None: + # Setup the memory by self.__call__() with memory and memory_seq_length. + # This will make the attention follow the keras convention which takes + # all the tensor inputs via __call__(). + if memory_sequence_length is None: + inputs = memory + else: + inputs = [memory, memory_sequence_length] + + self.values = super(_BaseAttentionMechanismV2, self).__call__( + inputs, setup_memory=True) def build(self, input_shape): - # The layer suppose to take 3 inputs, [query, state, memory]. - query_input_shape, _, memory_input_shape = input_shape - if self.query_layer is not None: - self.query_layer.build(query_input_shape) - if self.memory_layer is not None: - self.memory_layer.build(memory_input_shape) - # dtype of the layer is known at this moment, create the score_mask_value if - # needed. - self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) - self.built = True + if not self._memory_initialized: + # This is for setting up the memory, which contains memory and optional + # memory_sequence_length. Build the memory_layer with memory shape. + if self.memory_layer is not None and not self.memory_layer.built: + if isinstance(input_shape, list): + self.memory_layer.build(input_shape[0]) + else: + self.memory_layer.build(input_shape) + else: + # The input_shape should be query.shape and state.shape. Use the query + # to init the query layer. + if self.query_layer is not None and not self.query_layer.built: + self.query_layer.build(input_shape[0]) + + def __call__(self, inputs, **kwargs): + """Preprocess the inputs before calling `base_layer.__call__()`. + + Note that there are situation here, one for setup memory, and one with + actual query and state. + 1. When the memory has not been configured, we just pass all the param to + base_layer.__call__(), which will then invoke self.call() with proper + inputs, which allows this class to setup memory. + 2. When the memory has already been setup, the input should contain query + and state, and optionally processed memory. If the processed memory is + not included in the input, we will have to append it to the inputs and + give it to the base_layer.__call__(). The processed memory is the output + of first invocation of self.__call__(). If we don't add it here, then from + keras perspective, the graph is disconnected since the output from + previous call is never used. - def _setup_memory(self, memory, memory_mask=None): + Args: + inputs: the inputs tensors. + **kwargs: dict, other keyeword arguments for the `__call__()` + """ + if self._memory_initialized: + if len(inputs) not in (2, 3): + raise ValueError("Expect the inputs to have 2 or 3 tensors, got %d" % + len(inputs)) + if len(inputs) == 2: + # We append the calculated memory here so that the graph will be + # connected. + inputs.append(self.values) + return super(_BaseAttentionMechanismV2, self).__call__(inputs, **kwargs) + + def call(self, inputs, mask=None, setup_memory=False, **kwargs): + """Setup the memory or query the attention. + + There are two case here, one for setup memory, and the second is query the + attention score. `setup_memory` is the flag to indicate which mode it is. + The input list will be treated differently based on that flag. + + Args: + inputs: a list of tensor that could either be `query` and `state`, or + `memory` and `memory_sequence_length`. + `query` is the tensor of dtype matching `memory` and shape + `[batch_size, query_depth]`. + `state` is the tensor of dtype matching `memory` and shape + `[batch_size, alignments_size]`. (`alignments_size` is memory's + `max_time`). + `memory` is the memory to query; usually the output of an RNN encoder. + The tensor should be shaped `[batch_size, max_time, ...]`. + `memory_sequence_length` (optional) is the sequence lengths for the + batch entries in memory. If provided, the memory tensor rows are masked + with zeros for values past the respective sequence lengths. + mask: optional bool tensor with shape `[batch, max_time]` for the mask of + memory. If it is not None, the corresponding item of the memory should + be filtered out during calculation. + setup_memory: boolean, whether the input is for setting up memory, or + query attention. + **kwargs: Dict, other keyword arguments for the call method. + Returns: + Either processed memory or attention score, based on `setup_memory`. + """ + if setup_memory: + if isinstance(inputs, list): + if len(inputs) not in (1, 2): + raise ValueError("Expect inputs to have 1 or 2 tensors, got %d" % + len(inputs)) + memory = inputs[0] + memory_sequence_length = inputs[1] if len(inputs) == 2 else None + memory_mask = mask + else: + memory, memory_sequence_length = inputs, None + memory_mask = mask + self._setup_memory(memory, memory_sequence_length, memory_mask) + # We force the self.built to false here since only memory is initialized, + # but the real query/state has not been call() yet. The layer should be + # build and call again. + self.built = False + # Return the processed memory in order to create the Keras connectivity + # data for it. + return self.values + else: + if not self._memory_initialized: + raise ValueError("Cannot query the attention before the setup of " + "memory") + if len(inputs) not in (2, 3): + raise ValueError("Expect the inputs to have query, state, and optional " + "processed memory, got %d items" % len(inputs)) + # Ignore the rest of the inputs and only care about the query and state + query, state = inputs[0], inputs[1] + return self._calculate_attention(query, state) + + def _setup_memory(self, memory, memory_sequence_length=None, + memory_mask=None): """Pre-process the memory before actually query the memory. This should only be called once at the first invocation of call(). @@ -294,17 +428,30 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): Args: memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. - memory_mask: The boolean tensor with shape `[batch_size, max_time]`. For - any value equal to False, the corresponding value in memory should be - ignored. + memory_sequence_length (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros for + values past the respective sequence lengths. + memory_mask: (Optional) The boolean tensor with shape `[batch_size, + max_time]`. For any value equal to False, the corresponding value in + memory should be ignored. """ if self._memory_initialized: raise ValueError("The memory for the attention has already been setup.") + if memory_sequence_length is not None and memory_mask is not None: + raise ValueError("memory_sequence_length and memory_mask cannot be " + "used at same time for attention.") with ops.name_scope( self.name, "BaseAttentionMechanismInit", nest.flatten(memory)): self.values = _prepare_memory( - memory, memory_mask=memory_mask, + memory, + memory_sequence_length=memory_sequence_length, + memory_mask=memory_mask, check_inner_dims_defined=self._check_inner_dims_defined) + # Mark the value as check since the memory and memory mask might not + # passed from __call__(), which does not have proper keras metadata. + # TODO(omalleyt): Remove this hack once the mask the has proper keras + # history. + base_layer_utils.mark_checked(self.values) if self.memory_layer is not None: self.keys = self.memory_layer(self.values) else: @@ -315,36 +462,25 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1]) or array_ops.shape(self.keys)[1]) if memory_mask is not None: - self.probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda - self.probability_fn(_maybe_mask_score( - score, self.score_mask_value, memory_mask=memory_mask), prev)) + unwrapped_probability_fn = self.probability_fn + def _mask_probability_fn(score, prev): + return unwrapped_probability_fn( + _maybe_mask_score( + score, + memory_mask=memory_mask, + memory_sequence_length=memory_sequence_length, + score_mask_value=self.score_mask_value), prev) + self.probability_fn = _mask_probability_fn self._memory_initialized = True - def call(self, inputs, mask=None, **kwargs): - """Base method to calculate the attention score. - - Args: - inputs: a list of tensor that contains `query`, `state`, and `memory`. - `query` is the tensor of dtype matching `memory` and shape - `[batch_size, query_depth]`. - `state` is the tensor of dtype matching `memory` and shape - `[batch_size, alignments_size]`. (`alignments_size` is memory's - `max_time`). - `memory` is the memory to query; usually the output of an RNN encoder. - This tensor should be shaped `[batch_size, max_time, feature]`. - mask: optional bool tensor with shape `[batch, max_time]` for the mask of - memory. If it is not None, the corresponding item of the memory should - be filtered out during calculation. - **kwargs: Dict, other keyword arguments for the call method. - """ - query, state, memory, memory_mask = self._process_inputs(inputs, mask) - if not self._memory_initialized: - self._setup_memory(memory, memory_mask=memory_mask) - return self.calculate_attention(query, state) - - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): raise NotImplementedError( - "calculate_attention need to be implemented by subclasses.") + "_calculate_attention need to be implemented by subclasses.") + + def compute_mask(self, inputs, mask=None): + # There real input of the attention is query and state, and the memory layer + # mask shouldn't be pass down. Returning None for all output mask here. + return None, None def get_config(self): config = {} @@ -361,16 +497,12 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): "class_name": self.memory_layer.__class__.__name__, "config": self.memory_layer.get_config(), } + # memory is a required init parameter and its a tensor. It cannot be + # serialized to config, so we put a placeholder for it. + config["memory"] = None base_config = super(_BaseAttentionMechanismV2, self).get_config() return dict(list(base_config.items()) + list(config.items())) - def _process_inputs(self, inputs, mask): - if len(inputs) != 3: - raise ValueError( - "Expect to have 3 inputs for attention, got %d" % len(inputs)) - query, state, memory = inputs - return query, state, memory, mask - def _process_probability_fn(self, func_name): """Helper method to retrieve the probably function by string input.""" valid_probability_fns = { @@ -418,6 +550,46 @@ class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer): def alignments_size(self): return self._alignments_size + @property + def state_size(self): + return self._alignments_size + + def initial_alignments(self, batch_size, dtype): + """Creates the initial alignment values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return a tensor of all zeros. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A `dtype` tensor shaped `[batch_size, alignments_size]` + (`alignments_size` is the values' `max_time`). + """ + max_time = self._alignments_size + return _zero_state_tensors(max_time, batch_size, dtype) + + def initial_state(self, batch_size, dtype): + """Creates the initial state values for the `AttentionWrapper` class. + + This is important for AttentionMechanisms that use the previous alignment + to calculate the alignment at the next time step (e.g. monotonic attention). + + The default behavior is to return the same output as initial_alignments. + + Args: + batch_size: `int32` scalar, the batch_size. + dtype: The `dtype`. + + Returns: + A structure of all-zero tensors with shapes as described by `state_size`. + """ + return self.initial_alignments(batch_size, dtype) + def _luong_score(query, keys, scale): """Implements Luong-style (multiplicative) scoring function. @@ -587,6 +759,8 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, scale=False, probability_fn="softmax", dtype=None, @@ -596,6 +770,11 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): Args: units: The depth of the attention mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. probability_fn: (optional) string, the name of function to convert the attention score to probabilities. The default is `softmax` which is @@ -618,26 +797,27 @@ class LuongAttentionV2(_BaseAttentionMechanismV2): if not memory_layer: memory_layer = layers.Dense( units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.scale = scale + self.scale_weight = None super(LuongAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, query_layer=None, memory_layer=memory_layer, probability_fn=wrapped_probability_fn, name=name, dtype=dtype, **kwargs) - self.units = units - self.scale = scale def build(self, input_shape): super(LuongAttentionV2, self).build(input_shape) - if self.scale: + if self.scale and self.scale_weight is None: self.scale_weight = self.add_weight( "attention_g", initializer=init_ops.ones_initializer, shape=()) - else: - self.scale_weight = None self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. Args: @@ -851,8 +1031,11 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, normalize=False, probability_fn="softmax", + kernel_initializer="glorot_uniform", dtype=None, name="BahdanauAttention", **kwargs): @@ -860,12 +1043,19 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): Args: units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. normalize: Python boolean. Whether to normalize the energy term. probability_fn: (optional) string, the name of function to convert the attention score to probabilities. The default is `softmax` which is `tf.nn.softmax`. Other options is `hardmax`, which is hardmax() within this module. Any other value will result into validation error. Default to use `softmax`. + kernel_initializer: (optional), the name of the initializer for the + attention kernel. dtype: The data type for the query and memory layers of the attention mechanism. name: Name to use when creating ops. @@ -885,33 +1075,39 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): if not memory_layer: memory_layer = layers.Dense( units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.normalize = normalize + self.kernel_initializer = initializers.get(kernel_initializer) + self.attention_v = None + self.attention_g = None + self.attention_b = None super(BahdanauAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, query_layer=query_layer, memory_layer=memory_layer, probability_fn=wrapped_probability_fn, name=name, dtype=dtype, **kwargs) - self.units = units - self.normalize = normalize def build(self, input_shape): super(BahdanauAttentionV2, self).build(input_shape) - self.attention_v = self.add_weight( - "attention_v", [self.units], dtype=self.dtype) - if self.normalize: + if self.attention_v is None: + self.attention_v = self.add_weight( + "attention_v", [self.units], + dtype=self.dtype, + initializer=self.kernel_initializer) + if self.normalize and self.attention_g is None and self.attention_b is None: self.attention_g = self.add_weight( "attention_g", initializer=init_ops.constant_initializer( math.sqrt((1. / self.units))), shape=()) self.attention_b = self.add_weight( "attention_b", shape=[self.units], initializer=init_ops.zeros_initializer()) - else: - self.attention_g = None - self.attention_b = None self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. Args: @@ -940,6 +1136,7 @@ class BahdanauAttentionV2(_BaseAttentionMechanismV2): "units": self.units, "normalize": self.normalize, "probability_fn": self.probability_fn_name, + "kernel_initializer": initializers.serialize(self.kernel_initializer) } base_config = super(BahdanauAttentionV2, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1299,11 +1496,14 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, normalize=False, sigmoid_noise=0., sigmoid_noise_seed=None, score_bias_init=0., mode="parallel", + kernel_initializer="glorot_uniform", dtype=None, name="BahdanauMonotonicAttention", **kwargs): @@ -1311,6 +1511,11 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): Args: units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. normalize: Python boolean. Whether to normalize the energy term. sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring for `_monotonic_probability_fn` for more information. @@ -1321,6 +1526,8 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for `tf.contrib.seq2seq.monotonic_attention` for more information. + kernel_initializer: (optional), the name of the initializer for the + attention kernel. dtype: The data type for the query and memory layers of the attention mechanism. name: Name to use when creating ops. @@ -1341,32 +1548,39 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): if not memory_layer: memory_layer = layers.Dense( units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.normalize = normalize + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + self.kernel_initializer = initializers.get(kernel_initializer) + self.attention_v = None + self.attention_score_bias = None + self.attention_g = None + self.attention_b = None super(BahdanauMonotonicAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, query_layer=query_layer, memory_layer=memory_layer, probability_fn=wrapped_probability_fn, name=name, dtype=dtype, **kwargs) - self.units = units - self.normalize = normalize - self.sigmoid_noise = sigmoid_noise - self.sigmoid_noise_seed = sigmoid_noise_seed - self.score_bias_init = score_bias_init - self.mode = mode def build(self, input_shape): super(BahdanauMonotonicAttentionV2, self).build(input_shape) - self.attention_v = self.add_weight( - "attention_v", [self.units], dtype=self.dtype) - self.attention_score_bias = self.add_weight( - "attention_score_bias", shape=(), dtype=self.dtype, - initializer=init_ops.constant_initializer( - self.score_bias_init, dtype=self.dtype)) - if not self.normalize: - self.attention_g = None - self.attention_b = None - else: + if self.attention_v is None: + self.attention_v = self.add_weight( + "attention_v", [self.units], dtype=self.dtype, + initializer=self.kernel_initializer) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), dtype=self.dtype, + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) + if self.normalize and self.attention_g is None and self.attention_b is None: self.attention_g = self.add_weight( "attention_g", dtype=self.dtype, initializer=init_ops.constant_initializer( @@ -1377,7 +1591,7 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): initializer=init_ops.zeros_initializer()) self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. Args: @@ -1409,6 +1623,7 @@ class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): "sigmoid_noise_seed": self.sigmoid_noise_seed, "score_bias_init": self.score_bias_init, "mode": self.mode, + "kernel_initializer": initializers.serialize(self.kernel_initializer), } base_config = super(BahdanauMonotonicAttentionV2, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1542,6 +1757,8 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): def __init__(self, units, + memory, + memory_sequence_length=None, scale=False, sigmoid_noise=0., sigmoid_noise_seed=None, @@ -1554,6 +1771,11 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): Args: units: The depth of the query mechanism. + memory: The memory to query; usually the output of an RNN encoder. This + tensor should be shaped `[batch_size, max_time, ...]`. + memory_sequence_length: (optional): Sequence lengths for the batch entries + in memory. If provided, the memory tensor rows are masked with zeros + for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring for `_monotonic_probability_fn` for more information. @@ -1580,34 +1802,37 @@ class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): if not memory_layer: memory_layer = layers.Dense( units, name="memory_layer", use_bias=False, dtype=dtype) + self.units = units + self.scale = scale + self.sigmoid_noise = sigmoid_noise + self.sigmoid_noise_seed = sigmoid_noise_seed + self.score_bias_init = score_bias_init + self.mode = mode + self.attention_g = None + self.attention_score_bias = None super(LuongMonotonicAttentionV2, self).__init__( + memory=memory, + memory_sequence_length=memory_sequence_length, query_layer=None, memory_layer=memory_layer, probability_fn=wrapped_probability_fn, name=name, dtype=dtype, **kwargs) - self.units = units - self.scale = scale - self.sigmoid_noise = sigmoid_noise - self.sigmoid_noise_seed = sigmoid_noise_seed - self.score_bias_init = score_bias_init - self.mode = mode def build(self, input_shape): super(LuongMonotonicAttentionV2, self).build(input_shape) - if self.scale: + if self.scale and self.attention_g is None: self.attention_g = self.add_weight( "attention_g", initializer=init_ops.ones_initializer, shape=()) - else: - self.attention_g = None - self.attention_score_bias = self.add_weight( - "attention_score_bias", shape=(), - initializer=init_ops.constant_initializer( - self.score_bias_init, dtype=self.dtype)) + if self.attention_score_bias is None: + self.attention_score_bias = self.add_weight( + "attention_score_bias", shape=(), + initializer=init_ops.constant_initializer( + self.score_bias_init, dtype=self.dtype)) self.built = True - def calculate_attention(self, query, state): + def _calculate_attention(self, query, state): """Score the query based on the keys and values. Args: @@ -1695,7 +1920,15 @@ class AttentionWrapperState( def with_same_shape(old, new): """Check and set new tensor's shape.""" if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor): - return tensor_util.with_same_shape(old, new) + if not context.executing_eagerly(): + return tensor_util.with_same_shape(old, new) + else: + if old.shape.as_list() != new.shape.as_list(): + raise ValueError("The shape of the AttentionWrapperState is " + "expected to be same as the one to clone. " + "self.shape: %s, input.shape: %s" % + (old.shape, new.shape)) + return new return new return nest.map_structure( @@ -1739,41 +1972,26 @@ def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, "but saw shape: %s" % (m.name, m.get_shape())) nest.map_structure(_check_dims, memory) if memory_sequence_length is None and memory_mask is None: - seq_len_mask = None - seq_len_batch_size = None + return memory elif memory_sequence_length is not None: seq_len_mask = array_ops.sequence_mask( memory_sequence_length, maxlen=array_ops.shape(nest.flatten(memory)[0])[1], dtype=nest.flatten(memory)[0].dtype) - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_sequence_length.shape[0]) - or array_ops.shape(memory_sequence_length)[0]) else: # For memory_mask is not None - seq_len_mask = memory_mask - seq_len_batch_size = ( - tensor_shape.dimension_value(memory_mask.shape[0]) - or array_ops.shape(memory_mask)[0]) + seq_len_mask = math_ops.cast( + memory_mask, dtype=nest.flatten(memory)[0].dtype) def _maybe_mask(m, seq_len_mask): """Mask the memory based on the memory mask.""" rank = m.get_shape().ndims rank = rank if rank is not None else array_ops.rank(m) extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) - m_batch_size = tensor_shape.dimension_value( - m.shape[0]) or array_ops.shape(m)[0] - if seq_len_batch_size is not None: - message = ("memory_sequence_length and memory tensor batch sizes do not " - "match.") - with ops.control_dependencies([ - check_ops.assert_equal( - seq_len_batch_size, m_batch_size, message=message)]): - seq_len_mask = array_ops.reshape( - seq_len_mask, - array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) - return m * seq_len_mask - else: - return m + seq_len_mask = array_ops.reshape( + seq_len_mask, + array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) + return m * seq_len_mask + return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) @@ -1819,8 +2037,14 @@ def hardmax(logits, name=None): def _compute_attention(attention_mechanism, cell_output, attention_state, attention_layer): """Computes the attention and alignments for a given attention_mechanism.""" - alignments, next_attention_state = attention_mechanism( - cell_output, state=attention_state) + if isinstance(attention_mechanism, _BaseAttentionMechanismV2): + alignments, next_attention_state = attention_mechanism( + [cell_output, attention_state]) + else: + # For other class, assume they are following _BaseAttentionMechanism, which + # takes query and state as separate parameter. + alignments, next_attention_state = attention_mechanism( + cell_output, state=attention_state) # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] expanded_alignments = array_ops.expand_dims(alignments, 1) @@ -1833,13 +2057,13 @@ def _compute_attention(attention_mechanism, cell_output, attention_state, # the batched matmul is over memory_time, so the output shape is # [batch_size, 1, memory_size]. # we then squeeze out the singleton dim. - context = math_ops.matmul(expanded_alignments, attention_mechanism.values) - context = array_ops.squeeze(context, [1]) + context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values) + context_ = array_ops.squeeze(context_, [1]) if attention_layer is not None: - attention = attention_layer(array_ops.concat([cell_output, context], 1)) + attention = attention_layer(array_ops.concat([cell_output, context_], 1)) else: - attention = context + attention = context_ return attention, alignments, next_attention_state diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 8f8f057702951094758b277ce060955f3dc6e99d..1d773a449890cd7335b2225db39d79ca958a3276 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -24,11 +24,12 @@ import numpy as np from tensorflow.contrib.seq2seq.python.ops import attention_wrapper from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.layers import base as layers_base +from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops @@ -182,11 +183,12 @@ def gather_tree_from_array(t, parent_ids, sequence_length): return ordered -def _check_maybe(t): +def _check_ndims(t): if t.shape.ndims is None: raise ValueError( "Expected tensor (%s) to have known rank, but ndims == None." % t) + def _check_static_batch_beam_maybe(shape, batch_size, beam_width): """Raises an exception if dimensions are known statically and can not be reshaped to [batch_size, beam_size, -1]. @@ -205,6 +207,7 @@ def _check_static_batch_beam_maybe(shape, batch_size, beam_width): return False return True + def _check_batch_beam(t, batch_size, beam_width): """Returns an Assert operation checking that the elements of the stacked TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point, @@ -229,70 +232,30 @@ def _check_batch_beam(t, batch_size, beam_width): return control_flow_ops.Assert(condition, [error_message]) +class BeamSearchDecoderMixin(object): + """BeamSearchDecoderMixin contains the common methods for BeamSearchDecoder. -class BeamSearchDecoder(decoder.Decoder): - """BeamSearch sampling decoder. - - **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in - `AttentionWrapper`, then you must ensure that: - - - The encoder output has been tiled to `beam_width` via - `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). - - The `batch_size` argument passed to the `zero_state` method of this - wrapper is equal to `true_batch_size * beam_width`. - - The initial state created with `zero_state` above contains a - `cell_state` value containing properly tiled final state from the - encoder. - - An example: - - ``` - tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( - encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( - encoder_final_state, multiplier=beam_width) - tiled_sequence_length = tf.contrib.seq2seq.tile_batch( - sequence_length, multiplier=beam_width) - attention_mechanism = MyFavoriteAttentionMechanism( - num_units=attention_depth, - memory=tiled_inputs, - memory_sequence_length=tiled_sequence_length) - attention_cell = AttentionWrapper(cell, attention_mechanism, ...) - decoder_initial_state = attention_cell.zero_state( - dtype, batch_size=true_batch_size * beam_width) - decoder_initial_state = decoder_initial_state.clone( - cell_state=tiled_encoder_final_state) - ``` - - Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use - when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages - the translation to cover all inputs. + It is expected to be used a base class for concrete BeamSearchDecoder. Since + this is a mixin class, it is expected to be used together with other class as + base. """ def __init__(self, cell, - embedding, - start_tokens, - end_token, - initial_state, beam_width, output_layer=None, length_penalty_weight=0.0, coverage_penalty_weight=0.0, - reorder_tensor_arrays=True): - """Initialize the BeamSearchDecoder. + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderMixin. Args: cell: An `RNNCell` instance. - embedding: A callable that takes a vector tensor of `ids` (argmax ids), - or the `params` argument for `embedding_lookup`. - start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. - end_token: `int32` scalar, the token that marks end of decoding. - initial_state: A (possibly nested tuple of...) tensors and TensorArrays. beam_width: Python integer, the number of beams. - output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., - `tf.layers.Dense`. Optional layer to apply to the RNN output prior - to storing the result or sampling. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. coverage_penalty_weight: Float weight to penalize the coverage of source sentence. Disabled with 0.0. @@ -302,59 +265,35 @@ class BeamSearchDecoder(decoder.Decoder): Otherwise, the `TensorArray` will be returned as is. Set this flag to `False` if the cell state contains `TensorArray`s that are not amenable to reordering. + **kwargs: Dict, other keyword arguments for parent class. Raises: TypeError: if `cell` is not an instance of `RNNCell`, - or `output_layer` is not an instance of `tf.layers.Layer`. - ValueError: If `start_tokens` is not a vector or - `end_token` is not a scalar. + or `output_layer` is not an instance of `tf.keras.layers.Layer`. """ rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and - not isinstance(output_layer, layers_base.Layer)): + not isinstance(output_layer, layers.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer self._reorder_tensor_arrays = reorder_tensor_arrays - if callable(embedding): - self._embedding_fn = embedding - else: - self._embedding_fn = ( - lambda ids: embedding_ops.embedding_lookup(embedding, ids)) - - self._start_tokens = ops.convert_to_tensor( - start_tokens, dtype=dtypes.int32, name="start_tokens") - if self._start_tokens.get_shape().ndims != 1: - raise ValueError("start_tokens must be a vector") - self._end_token = ops.convert_to_tensor( - end_token, dtype=dtypes.int32, name="end_token") - if self._end_token.get_shape().ndims != 0: - raise ValueError("end_token must be a scalar") - - self._batch_size = array_ops.size(start_tokens) + self._start_tokens = None + self._end_token = None + self._batch_size = None self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._coverage_penalty_weight = coverage_penalty_weight - self._initial_cell_state = nest.map_structure( - self._maybe_split_batch_beams, initial_state, self._cell.state_size) - self._start_tokens = array_ops.tile( - array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) - self._start_inputs = self._embedding_fn(self._start_tokens) - - self._finished = array_ops.one_hot( - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=False, - off_value=True, - dtype=dtypes.bool) + super(BeamSearchDecoderMixin, self).__init__(**kwargs) @property def batch_size(self): return self._batch_size def _rnn_output_size(self): + """Get the output shape from the RNN layer.""" size = self._cell.output_size if self._output_layer is None: return size @@ -393,50 +332,6 @@ class BeamSearchDecoder(decoder.Decoder): predicted_ids=tensor_shape.TensorShape([self._beam_width]), parent_ids=tensor_shape.TensorShape([self._beam_width])) - @property - def output_dtype(self): - # Assume the dtype of the cell is the output_size structure - # containing the input_state's first component's dtype. - # Return that structure and int32 (the id) - dtype = nest.flatten(self._initial_cell_state)[0].dtype - return BeamSearchDecoderOutput( - scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), - predicted_ids=dtypes.int32, - parent_ids=dtypes.int32) - - def initialize(self, name=None): - """Initialize the decoder. - - Args: - name: Name scope for any created operations. - - Returns: - `(finished, start_inputs, initial_state)`. - """ - finished, start_inputs = self._finished, self._start_inputs - - dtype = nest.flatten(self._initial_cell_state)[0].dtype - log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) - array_ops.zeros([self._batch_size], dtype=dtypes.int32), - depth=self._beam_width, - on_value=ops.convert_to_tensor(0.0, dtype=dtype), - off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), - dtype=dtype) - init_attention_probs = get_attention_probs( - self._initial_cell_state, self._coverage_penalty_weight) - if init_attention_probs is None: - init_attention_probs = () - - initial_state = BeamSearchDecoderState( - cell_state=self._initial_cell_state, - log_probs=log_probs, - finished=finished, - lengths=array_ops.zeros( - [self._batch_size, self._beam_width], dtype=dtypes.int64), - accumulated_attention_probs=init_attention_probs) - - return (finished, start_inputs, initial_state) - def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. @@ -562,7 +457,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 1: return self._split_batch_beams(t, s) else: @@ -586,7 +481,7 @@ class BeamSearchDecoder(decoder.Decoder): """ if isinstance(t, tensor_array_ops.TensorArray): return t - _check_maybe(t) + _check_ndims(t) if t.shape.ndims >= 2: return self._merge_batch_beams(t, s) else: @@ -609,11 +504,18 @@ class BeamSearchDecoder(decoder.Decoder): if not isinstance(t, tensor_array_ops.TensorArray): return t # pylint: disable=protected-access - if (not t._infer_shape or not t._element_shape - or t._element_shape[0].ndims is None - or t._element_shape[0].ndims < 1): + # This is a bad hack due to the implementation detail of eager/graph TA. + # TODO(b/124374427): Update this to use public property of TensorArray. + if context.executing_eagerly(): + element_shape = t._element_shape + else: + element_shape = t._element_shape[0] + if (not t._infer_shape + or not t._element_shape + or element_shape.ndims is None + or element_shape.ndims < 1): shape = ( - t._element_shape[0] if t._infer_shape and t._element_shape + element_shape if t._infer_shape and t._element_shape else tensor_shape.TensorShape(None)) tf_logging.warn("The TensorArray %s in the cell state is not amenable to " "sorting based on the beam search result. For a " @@ -621,10 +523,10 @@ class BeamSearchDecoder(decoder.Decoder): "defined and have at least a rank of 1, but saw shape: %s" % (t.handle.name, shape)) return t - shape = t._element_shape[0] # pylint: enable=protected-access if not _check_static_batch_beam_maybe( - shape, tensor_util.constant_value(self._batch_size), self._beam_width): + element_shape, tensor_util.constant_value(self._batch_size), + self._beam_width): return t t = t.stack() with ops.control_dependencies( @@ -684,6 +586,359 @@ class BeamSearchDecoder(decoder.Decoder): return (beam_search_output, beam_search_state, next_inputs, finished) +class BeamSearchDecoder(BeamSearchDecoderMixin, decoder.Decoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoder to cover all inputs. + """ + + def __init__(self, + cell, + embedding, + start_tokens, + end_token, + initial_state, + beam_width, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True): + """Initialize the BeamSearchDecoder. + + Args: + cell: An `RNNCell` instance. + embedding: A callable that takes a vector tensor of `ids` (argmax ids), + or the `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + beam_width: Python integer, the number of beams. + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + ValueError: If `start_tokens` is not a vector or + `end_token` is not a scalar. + """ + super(BeamSearchDecoder, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays) + + if callable(embedding): + self._embedding_fn = embedding + else: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + def initialize(self, name=None): + """Initialize the decoder. + + Args: + name: Name scope for any created operations. + + Returns: + `(finished, start_inputs, initial_state)`. + """ + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + +class BeamSearchDecoderV2(BeamSearchDecoderMixin, decoder.BaseDecoder): + # Note that the inheritance hierarchy is important here. The Mixin has to be + # the first parent class since we will use super().__init__(), and Mixin which + # is a object will properly invoke the __init__ method of other parent class. + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + + Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use + when computing scores (https://arxiv.org/pdf/1609.08144.pdf). It encourages + the decoding to cover all inputs. + """ + + def __init__(self, + cell, + beam_width, + embedding_fn=None, + output_layer=None, + length_penalty_weight=0.0, + coverage_penalty_weight=0.0, + reorder_tensor_arrays=True, + **kwargs): + """Initialize the BeamSearchDecoderV2. + + Args: + cell: An `RNNCell` instance. + beam_width: Python integer, the number of beams. + embedding_fn: A callable that takes a vector tensor of `ids` (argmax ids). + output_layer: (Optional) An instance of `tf.keras.layers.Layer`, i.e., + `tf.keras.layers.Dense`. Optional layer to apply to the RNN output + prior to storing the result or sampling. + length_penalty_weight: Float weight to penalize length. Disabled with 0.0. + coverage_penalty_weight: Float weight to penalize the coverage of source + sentence. Disabled with 0.0. + reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell + state will be reordered according to the beam search path. If the + `TensorArray` can be reordered, the stacked form will be returned. + Otherwise, the `TensorArray` will be returned as is. Set this flag to + `False` if the cell state contains `TensorArray`s that are not amenable + to reordering. + **kwargs: Dict, other keyword arguments for initialization. + + Raises: + TypeError: if `cell` is not an instance of `RNNCell`, + or `output_layer` is not an instance of `tf.keras.layers.Layer`. + """ + super(BeamSearchDecoderV2, self).__init__( + cell, + beam_width, + output_layer=output_layer, + length_penalty_weight=length_penalty_weight, + coverage_penalty_weight=coverage_penalty_weight, + reorder_tensor_arrays=reorder_tensor_arrays, + **kwargs) + + if embedding_fn is None or callable(embedding_fn): + self._embedding_fn = embedding_fn + else: + raise ValueError("embedding_fn is expected to be a callable, got %s" % + type(embedding_fn)) + + def initialize(self, + embedding, + start_tokens, + end_token, + initial_state): + """Initialize the decoder. + + Args: + embedding: A tensor from the embedding layer output, which is the + `params` argument for `embedding_lookup`. + start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. + end_token: `int32` scalar, the token that marks end of decoding. + initial_state: A (possibly nested tuple of...) tensors and TensorArrays. + Returns: + `(finished, start_inputs, initial_state)`. + Raises: + ValueError: If `start_tokens` is not a vector or `end_token` is not a + scalar. + """ + if embedding is not None and self._embedding_fn is not None: + raise ValueError( + "embedding and embedding_fn cannot be provided at same time") + elif embedding is not None: + self._embedding_fn = ( + lambda ids: embedding_ops.embedding_lookup(embedding, ids)) + + self._start_tokens = ops.convert_to_tensor( + start_tokens, dtype=dtypes.int32, name="start_tokens") + if self._start_tokens.get_shape().ndims != 1: + raise ValueError("start_tokens must be a vector") + self._end_token = ops.convert_to_tensor( + end_token, dtype=dtypes.int32, name="end_token") + if self._end_token.get_shape().ndims != 0: + raise ValueError("end_token must be a scalar") + + self._batch_size = array_ops.size(start_tokens) + self._initial_cell_state = nest.map_structure( + self._maybe_split_batch_beams, initial_state, self._cell.state_size) + self._start_tokens = array_ops.tile( + array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) + self._start_inputs = self._embedding_fn(self._start_tokens) + + self._finished = array_ops.one_hot( + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=False, + off_value=True, + dtype=dtypes.bool) + + finished, start_inputs = self._finished, self._start_inputs + + dtype = nest.flatten(self._initial_cell_state)[0].dtype + log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) + array_ops.zeros([self._batch_size], dtype=dtypes.int32), + depth=self._beam_width, + on_value=ops.convert_to_tensor(0.0, dtype=dtype), + off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype), + dtype=dtype) + init_attention_probs = get_attention_probs( + self._initial_cell_state, self._coverage_penalty_weight) + if init_attention_probs is None: + init_attention_probs = () + + initial_state = BeamSearchDecoderState( + cell_state=self._initial_cell_state, + log_probs=log_probs, + finished=finished, + lengths=array_ops.zeros( + [self._batch_size, self._beam_width], dtype=dtypes.int64), + accumulated_attention_probs=init_attention_probs) + + return (finished, start_inputs, initial_state) + + @property + def output_dtype(self): + # Assume the dtype of the cell is the output_size structure + # containing the input_state's first component's dtype. + # Return that structure and int32 (the id) + dtype = nest.flatten(self._initial_cell_state)[0].dtype + return BeamSearchDecoderOutput( + scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), + predicted_ids=dtypes.int32, + parent_ids=dtypes.int32) + + def call(self, embeddning, start_tokens, end_token, initial_state, **kwargs): + init_kwargs = kwargs + init_kwargs["start_tokens"] = start_tokens + init_kwargs["end_token"] = end_token + init_kwargs["initial_state"] = initial_state + return decoder.dynamic_decode(self, + output_time_major=self.output_time_major, + impute_finished=self.impute_finished, + maximum_iterations=self.maximum_iterations, + parallel_iterations=self.parallel_iterations, + swap_memory=self.swap_memory, + decoder_init_input=embeddning, + decoder_init_kwargs=init_kwargs) + + def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight, coverage_penalty_weight): @@ -1068,7 +1323,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, """ if isinstance(gather_from, tensor_array_ops.TensorArray): return gather_from - _check_maybe(gather_from) + _check_ndims(gather_from) if gather_from.shape.ndims >= len(gather_shape): return _tensor_gather_helper( gather_indices=gather_indices, diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index f88b03ec4c2b1f250091594ea12d7d1862029fa2..7dd52df6b68caea6111813837ba1e872acbeccdb 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -4,17 +4,14 @@ exports_files([ "LICENSE", ]) -load( - "//tensorflow:tensorflow.bzl", - "py_test", - "tf_gen_op_wrapper_py", -) +load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "summary_ops_test", srcs = ["summary_ops_test.py"], srcs_version = "PY2AND3", deps = [ + ":summary", ":summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:errors", @@ -22,7 +19,6 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", @@ -35,6 +31,7 @@ py_test( srcs = ["summary_ops_graph_test.py"], srcs_version = "PY2AND3", deps = [ + ":summary", ":summary_test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -43,7 +40,6 @@ py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python:variables", "@six_archive//:six", diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py index 807741e05f92f6b666c175269742dc1af50c0054..8e13f7f56b23e47f046120b285b1519c6371ddab 100644 --- a/tensorflow/contrib/summary/summary_ops_graph_test.py +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -22,6 +22,7 @@ import time import six +from tensorflow.contrib.summary import summary as summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -32,7 +33,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 10e4556dacbc17ec02c2bd698389b04d517d7076..27bfdeb3601f4fdb9897feee509b06d5e8f9b873 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -25,6 +25,7 @@ import sqlite3 import numpy as np import six +from tensorflow.contrib.summary import summary as summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 @@ -36,7 +37,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 398ac314f4b520610ec100273b37c33bc4b5b43a..583bbf97c57cf263f65bc3b0a56b32cc2dce5482 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -537,8 +537,9 @@ py_library( py_test( name = "random_forest_test", - size = "large", + size = "medium", srcs = ["client/random_forest_test.py"], + shard_count = 6, srcs_version = "PY2AND3", tags = [ "noasan", diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index b9aad36f3d25b9fb7b8b525be54fb7a39394b373..76b1d2b4da269cda71f5b49878f2933d7d9b5776 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -304,7 +304,7 @@ class TraverseTreeV4Op : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, + auto traverse = [&set_leaf_ids, &data_set, decision_tree_resource, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index fe2c91c1047fe56710b1a86b2fa3206caf6ff3bc..0243f106814511c1b53a5aacb830b845214a00a3 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -307,7 +307,7 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids_tensor, &num_targets, &data_set, + auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); @@ -317,7 +317,7 @@ class ProcessInputOp : public OpKernel { static_cast(end), &ready_to_split); }; - auto update_collated = [this, &target, &num_targets, fertile_stats_resource, + auto update_collated = [&target, &num_targets, fertile_stats_resource, tree_resource, &leaf_examples, &set_lock, &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 290c16fe3966791ea78986539750caf938a37322..40bf7081a3f22dfd68fd46f0f61695ee9ca7863b 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -35,7 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _model_ops = loader.load_op_library( diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 9184198cd4c8fd2a7609714d094d5ef2b6868658..80afcfb251f4d6455a9eb8ba5df4a6e43d2feb1c 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.tracking import tracking _stats_ops = loader.load_op_library( diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 42bf1eda5179d0f72f4fd8432e6b5684f8e46917..4a959378138dec6f1c1a3f490704d7aebeae9b47 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -21,20 +21,21 @@ from __future__ import print_function from tensorflow.python.compiler.tensorrt import trt_convert -def create_inference_graph(input_graph_def, - outputs, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=trt_convert.TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batches=None, - use_calibration=True, - input_saved_model_dir=None, - input_saved_model_tags=None, - output_saved_model_dir=None, - session_config=None): +def create_inference_graph( + input_graph_def, + outputs, + max_batch_size=1, + max_workspace_size_bytes=trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES, + precision_mode=trt_convert.TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=None, + use_calibration=True, + input_saved_model_dir=None, + input_saved_model_tags=None, + output_saved_model_dir=None, + session_config=None): return trt_convert.create_inference_graph( input_graph_def=input_graph_def, outputs=outputs, diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index d1be31ddc799ce4c4ef9baa15729fde7925f2f6c..4ba814b9e3d3621f9ab924961e2740885fa93b33 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -161,7 +161,10 @@ py_test( ], shard_count = 10, srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], # b/63391119 + tags = [ + "no_pip_gpu", # b/63391119 + "notap", # b/124520733 + ], deps = [ ":estimators", ":feature_keys", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 1859dee9d08ac4a8f3f496222d537b622c65621e..7c1661d20f15f94a929a46dafc79d59ca73e53cb 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -23,17 +23,12 @@ package( ], ) -cc_library( - name = "all_ops", +py_library( + name = "tpu_py", + srcs = ["python/ops/tpu_ops.py"], + srcs_version = "PY2AND3", deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", + "//tensorflow/python/tpu:tpu_py", ], ) @@ -42,19 +37,7 @@ py_library( srcs = ["python/tpu/async_checkpoint.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/tpu:async_checkpoint", ], ) @@ -75,145 +58,20 @@ py_library( ":functional", ":tpu_embedding", ":tpu_lib", - ":tpu_ordinal_selector_py", "//tensorflow/contrib/training:training_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:function", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:summary_ops_v2", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:util", - "@six_archive//:six", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "cross_replica_ops", - "heartbeat_ops", - "host_compute_ops", - "infeed_ops", - "outfeed_ops", - "replication_ops", - "tpu_configuration_ops", - "tpu_embedding_ops", - "tpu_ordinal_selector_op", - "functional_ops", - ], - deps = [ - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ops.so", - srcs = [ - "ops/cross_replica_ops.cc", - "ops/heartbeat_ops.cc", - "ops/host_compute_ops.cc", - "ops/infeed_ops.cc", - "ops/outfeed_ops.cc", - "ops/replication_ops.cc", - "ops/tpu_configuration_ops.cc", - "ops/tpu_embedding_ops.cc", - ], - deps = [ - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc", - "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils", - "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils", - "//tensorflow/core:lib_proto_parsing", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ops", - hidden = [ - "SendTPUEmbeddingGradients", - "EnqueueTPUEmbeddingIntegerBatch", - "EnqueueTPUEmbeddingSparseBatch", - "EnqueueTPUEmbeddingSparseTensorBatch", - ], - deps = [ - ":cross_replica_ops_op_lib", - ":heartbeat_ops_op_lib", - ":host_compute_ops_op_lib", - ":infeed_ops_op_lib", - ":outfeed_ops_op_lib", - ":replication_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ], -) - -tf_custom_op_library( - name = "python/ops/_tpu_ordinal_selector_op.so", - srcs = ["ops/tpu_ordinal_selector_op.cc"], -) - -tf_custom_op_py_library( - name = "tpu_ordinal_selector_py", - srcs = ["python/ops/tpu_ordinal_selector_op.py"], - dso = [":python/ops/_tpu_ordinal_selector_op.so"], - kernels = [ - ":tpu_ordinal_selector_op_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":tpu_ordinal_selector_op", - ], -) - -tf_gen_op_wrapper_py( - name = "tpu_ordinal_selector_op", - deps = [ - ":tpu_ordinal_selector_op_op_lib", + "//tensorflow/python/tpu:tpu_estimator", ], ) -tf_custom_op_library( - name = "python/ops/_functional_ops.so", - srcs = ["ops/functional_ops.cc"], -) - -tf_gen_op_wrapper_py( - name = "gen_functional_ops", - out = "python/tpu/gen_functional_ops.py", - hidden = [ - "TPUPartitionedCall", - ], - deps = [":functional_ops_op_lib"], -) - -tf_custom_op_py_library( +py_library( name = "functional", srcs = ["python/tpu/functional.py"], - dso = [":python/ops/_functional_ops.so"], - kernels = [ - ":functional_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = [ "//visibility:public", ], deps = [ - ":gen_functional_ops", + "//tensorflow/python/tpu:functional", ], ) @@ -222,30 +80,7 @@ py_library( srcs = ["python/profiler/__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_pb2_grpc", - "//tensorflow/contrib/tpu/profiler:tpu_profiler_analysis_proto_py", - "//tensorflow/contrib/tpu/profiler:trace_events_proto_py", - "//tensorflow/python:util", - ], -) - -tf_custom_op_py_library( - name = "tpu_py", - srcs = ["python/ops/tpu_ops.py"], - dso = [":python/ops/_tpu_ops.so"], - kernels = [ - ":all_ops", - ], - srcs_version = "PY2AND3", - deps = [ - ":profiler", - ":tpu_ops", - "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow/python/tpu/profiler", ], ) @@ -262,6 +97,7 @@ py_library( ":tpu_embedding", ":tpu_estimator", ":tpu_lib", + "//tensorflow/python/tpu", ], ) @@ -284,8 +120,8 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/distribute", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/core:protos_all_py", + "//tensorflow/core/protobuf/tpu:compilation_result_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -325,30 +161,12 @@ py_library( srcs_version = "PY2AND3", deps = [ ":datasets", + ":functional", ":profiler", ":tpu_py", - "//tensorflow/compiler/xla/experimental/xla_sharding", - "//tensorflow/compiler/xla/python_api:xla_shape", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/compiler:xla", - "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", - "//tensorflow/contrib/tpu/proto:dynamic_padding_proto_py", - "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py", - "//tensorflow/contrib/tpu/proto:topology_proto_py", - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", - "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:control_flow_util", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/losses", + "//tensorflow/python/tpu:tpu_lib", ], ) @@ -359,125 +177,20 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", - ], -) - -tf_py_test( - name = "datasets_test", - size = "medium", - srcs = ["python/tpu/datasets_test.py"], - additional_deps = [ - "//tensorflow/python:client_testlib", - ":datasets", - ], - grpc_enabled = True, - shard_count = 4, - tags = ["no_oss"], -) - -tf_py_test( - name = "tpu_test", - size = "small", - srcs = ["python/tpu/tpu_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:layers", - ], - tags = ["no_windows"], # TODO: needs investigation on Windows -) - -tf_py_test( - name = "tpu_sharding_test", - size = "small", - srcs = ["python/tpu/tpu_sharding_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - ], -) - -tf_py_test( - name = "bfloat16_test", - size = "small", - srcs = ["python/tpu/bfloat16_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - ], -) - -tf_py_test( - name = "tpu_infeed_test", - size = "small", - srcs = ["python/tpu/tpu_infeed_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "tpu_config_test", - size = "small", - srcs = ["python/tpu/tpu_config_test.py"], - additional_deps = [ - ":tpu_estimator", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "tpu_estimator_signals_test", - size = "small", - srcs = ["python/tpu/tpu_estimator_signals_test.py"], - additional_deps = [ - ":tpu_estimator", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - ], -) - -tf_py_test( - name = "topology_test", - size = "medium", - srcs = ["python/tpu/topology_test.py"], - additional_deps = [ - ":tpu", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python/tpu:datasets", ], ) py_library( name = "tpu_embedding", - srcs = ["python/tpu/tpu_embedding.py"], + srcs = [ + "python/tpu/tpu_embedding.py", + "python/tpu/tpu_embedding_gradient.py", + ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_ops", - "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@six_archive//:six", + "//tensorflow/python/tpu:tpu_embedding", ], ) @@ -486,31 +199,6 @@ py_library( srcs = ["python/tpu/feature_column.py"], deps = [ ":tpu_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", - ], -) - -tf_py_test( - name = "feature_column_test", - srcs = [ - "python/tpu/feature_column_test.py", - ], - additional_deps = [ - ":feature_column", - "//third_party/py/numpy", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/tpu:feature_column", ], - main = "python/tpu/feature_column_test.py", ) diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 541fbf33a302a4d850422885fdbbc438bd6b9b7b..9fb29f5e17bde9864a7e9b85d4abc2f6a4681b42 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -2,35 +2,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_cc") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos") - -tf_proto_library( - name = "tpu_profiler_proto", - srcs = ["tpu_profiler.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":op_profile_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -cc_library( - name = "dump_tpu_profile", - srcs = ["dump_tpu_profile.cc"], - hdrs = ["dump_tpu_profile.h"], - visibility = ["//visibility:public"], - deps = [ - ":op_profile_proto_cc", - ":tpu_profiler_proto_cc", - ":trace_events_proto_cc", - ":trace_events_to_json", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) cc_library( name = "version", @@ -45,69 +16,9 @@ tf_cc_binary( ], visibility = ["//visibility:public"], deps = [ - ":dump_tpu_profile", - ":tpu_profiler_analysis_proto_cc", - ":tpu_profiler_proto_cc", ":version", - "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/profiler/rpc/client:capture_profile", ], ) - -tf_proto_library( - name = "trace_events_proto", - srcs = ["trace_events.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - -cc_library( - name = "trace_events_to_json", - srcs = ["trace_events_to_json.cc"], - hdrs = ["trace_events_to_json.h"], - deps = [ - ":trace_events_proto_cc", - "//tensorflow/core:lib", - "@jsoncpp_git//:jsoncpp", - ], -) - -tf_cc_test( - name = "trace_events_to_json_test", - srcs = ["trace_events_to_json_test.cc"], - deps = [ - ":trace_events_to_json", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@jsoncpp_git//:jsoncpp", - ], -) - -tf_proto_library( - name = "op_profile_proto", - srcs = ["op_profile.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "tpu_profiler_analysis_proto", - srcs = ["tpu_profiler_analysis.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - protodeps = [":tpu_profiler_proto"] + tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -py_library( - name = "tpu_profiler_analysis_pb2_grpc", - srcs = ["tpu_profiler_analysis_pb2_grpc.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [":tpu_profiler_analysis_proto_py"], -) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 1c5ea2d997a58ca57ddc212ffd56aad525e961da..f11d1a9f37eeb19b95a876bd68575022e6b91521 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,235 +18,11 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. -#include "grpcpp/grpcpp.h" - -#include -#include -#include - -#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" -#include "tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.grpc.pb.h" #include "tensorflow/contrib/tpu/profiler/version.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/util/command_line_flags.h" -namespace tensorflow { -namespace tpu { -namespace { - -using ::tensorflow::TPUProfileAnalysis; -using ::tensorflow::TPUProfiler; - -constexpr uint64 kMaxEvents = 1000000; - -string GetCurrentTimeStampAsString() { - char s[128]; - std::time_t t = std::time(nullptr); - CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0); - return s; -} - -Status ValidateHostPortPair(const string& host_port) { - uint32 port; - std::vector parts = str_util::Split(host_port, ':'); - // Must be host:port, port must be a number, host must not contain a '/', - // host also must not be empty. - if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || - parts[0].find("/") != string::npos || parts[0].empty()) { - return errors::InvalidArgument("Could not interpret \"", host_port, - "\" as a host-port pair."); - } - return Status::OK(); -} - -ProfileRequest PopulateProfileRequest(int duration_ms, - const string& repository_root, - const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request; - request.set_duration_ms(duration_ms); - request.set_max_events(kMaxEvents); - if (tensorflow::str_util::StartsWith(repository_root, "gs://")) { - // For backward compatibilities, only generate tracetable etc when the - // user provide a GCS path for model directory. - request.set_repository_root(repository_root); - request.set_session_id(session_id); - } - request.add_tools("op_profile"); - request.add_tools("input_pipeline"); - request.add_tools("memory_viewer"); - request.add_tools("overview_page"); - *request.mutable_opts() = opts; - return request; -} - -// Returns whether the returned trace is empty. -// Failure are handled by CHECK, i.e. abort() -bool Profile(const string& service_addr, const string& logdir, int duration_ms, - const string& repository_root, const string& session_id, - const ProfileOptions& opts) { - ProfileRequest request = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - ProfileResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); - - if (!response.encoded_trace().empty()) { - TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( - logdir, session_id, "", response, &std::cout)); - // Print this at the end so that it's not buried in irrelevant LOG messages. - std::cout - << "NOTE: using the trace duration " << duration_ms << "ms." - << std::endl - << "Set an appropriate duration (with --duration_ms) if you " - "don't see a full step in your trace or the captured trace is too " - "large." - << std::endl; - } - - return response.encoded_trace().empty(); -} - -// Start a new profiling session that include all the hosts included in -// hostnames, for the time interval of duration_ms. Possibly save the profiling -// result in the directory specified by repository_root and session_id. -bool NewSession(const string& service_addr, - const std::vector& hostnames, - int duration_ms, const string& repository_root, - const string& session_id, const ProfileOptions& opts) { - NewProfileSessionRequest new_session_request; - *new_session_request.mutable_request() = - PopulateProfileRequest(duration_ms, repository_root, session_id, opts); - new_session_request.set_repository_root(repository_root); - new_session_request.set_session_id(session_id); - for (const auto& hostname : hostnames) { - new_session_request.add_hosts(hostname); - } - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their - // `ValidateHostPortPair` checks for empty host string case. - channel_args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - // TODO(jiesun): GRPC support following relevant naming scheme: - // 1. dns:///host:port - // 2. ipv4:host:port or ipv6:[host]:port - // We might need to change the prefix which depends on what TPU name resolver - // will give us. - std::unique_ptr stub = - TPUProfileAnalysis::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - NewProfileSessionResponse new_session_response; - TF_QCHECK_OK(FromGrpcStatus( - stub->NewSession(&context, new_session_request, &new_session_response))); - - std::cout << "Profile session succeed for host(s):" - << str_util::Join(hostnames, ",") << std::endl; - return new_session_response.empty_trace(); -} - -// Starts tracing on a single or multiple TPU hosts and saves the result in the -// given logdir. If no trace was collected, retries tracing for -// num_tracing_attempts. -void StartTracing(const tensorflow::string& service_addr, - const tensorflow::string& logdir, - const tensorflow::string& workers_list, - bool include_dataset_ops, int duration_ms, - int num_tracing_attempts) { - // Use the current timestamp as the run name. - tensorflow::string session_id = GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - io::JoinPath(logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(workers_list, ","); - - bool empty_trace = false; - int remaining_attempts = num_tracing_attempts; - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(include_dataset_ops); - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, - repository_root, session_id, opts); - } else { - tensorflow::string tpu_master = service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. Automatically retrying." - << std::endl - << std::endl; - } - - if (empty_trace) { - std::cout << "No trace event is collected after " << num_tracing_attempts - << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." - << std::endl; - } -} - -MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { - MonitorRequest request; - request.set_duration_ms(duration_ms); - request.set_monitoring_level(monitoring_level); - return request; -} - -// Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). -void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, int num_queries) { - for (int query = 0; query < num_queries; ++query) { - MonitorRequest request = - PopulateMonitorRequest(duration_ms, monitoring_level); - - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits::max()); - std::unique_ptr stub = - TPUProfiler::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - MonitorResponse response; - TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); - - std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1 - << "):\n\n" - << response.data() << std::flush; - } -} - -} // namespace -} // namespace tpu -} // namespace tensorflow - int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; @@ -300,8 +76,9 @@ int main(int argc, char** argv) { std::cout << usage.c_str() << std::endl; return 2; } - tensorflow::Status status = - tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); + tensorflow::Status status; + status = + tensorflow::profiler::client::ValidateHostPortPair(FLAGS_service_addr); if (!status.ok()) { std::cout << status.error_message() << std::endl; std::cout << usage.c_str() << std::endl; @@ -324,12 +101,17 @@ int main(int argc, char** argv) { << FLAGS_service_addr << " for " << duration_ms << "ms and show metrics for " << num_queries << " time(s)." << std::endl; - tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, - FLAGS_monitoring_level, num_queries); + tensorflow::profiler::client::StartMonitoring( + FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries); } else { - tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, - FLAGS_workers_list, FLAGS_include_dataset_ops, - duration_ms, num_tracing_attempts); + status = tensorflow::profiler::client::StartTracing( + FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, + FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts); + if (!status.ok()) { + std::cout << status.error_message() << std::endl; + std::cout << usage.c_str() << std::endl; + return 2; + } } return 0; } diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 500dd2cd39d6b8747cebb95d0a01d8c5680427fe..8605bae5c128513186d8c03835dcf49d3e4b6fd9 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -1,394 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Operations for TPUs.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform - -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops import gen_tpu_ops - from tensorflow.contrib.tpu.ops.gen_tpu_ops import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ops.so")) - - def _create_default_group_assignment(): - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "cross_replica_sum should be used within a tpu_shard_context, but " - "got unset number_of_shards. Assuming 1.") - num_shards = 1 - group_assignment = [list(range(num_shards))] - return group_assignment - - def all_to_all(x, - concat_dimension, - split_dimension, - split_count, - group_assignment=None, - name=None): - """Exchange data across TPU replicas. - - Args: - x: The local tensor. - concat_dimension: The dimension number to concatenate. - split_dimension: The dimension number to split. - split_count: The number of splits, this number must equal to the sub-group - size(group_assignment.get_shape()[1]) - group_assignment: Optional 2d int32 lists with shape [num_groups, - num_replicas_per_group]. `group_assignment[i]` represents the replica - ids in the ith subgroup. - name: Optional op name. - - Returns: - A `Tensor` which is concatenated by data from different replicas. - """ - if group_assignment is None: - group_assignment = _create_default_group_assignment() - return gen_tpu_ops.all_to_all( - x, - group_assignment, - concat_dimension=concat_dimension, - split_dimension=split_dimension, - split_count=split_count, - name=name) - - @ops.RegisterGradient("AllToAll") - def _all_to_all_grad(op, grad): - # The gradient of a all-to-all is also a all-to-all but the - # split_dimension and concat_dimension is swapped. - # The graident with respect to group_assignment is None. - return [ - gen_tpu_ops.all_to_all( - grad, - op.inputs[1], - concat_dimension=op.get_attr("split_dimension"), - split_dimension=op.get_attr("concat_dimension"), - split_count=op.get_attr("split_count")), None - ] - - def cross_replica_sum(x, group_assignment=None, name=None): - """Sum the input tensor across replicas according to group_assignment. - - Args: - x: The local tensor to the sum. - group_assignment: Optional 2d int32 lists with shape [num_groups, - num_replicas_per_group]. `group_assignment[i]` represents the replica - ids in the ith subgroup. - name: Optional op name. - - Returns: - A `Tensor` which is summed across replicas. - """ - if group_assignment is None: - group_assignment = _create_default_group_assignment() - - return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) - - def collective_permute(x, source_target_pairs, name=None): - """Permute the input tensor across replicas given source_target_pairs. - - For each source_target_pair , we send replica a's input to replica b. - Each replica id must only appear once in the source column. Also it must - only appear once in the target column. - For the replica id not in the target column, this op returns a zero tensor - with the same shape and dtype of the input x. - - For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing - source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: - `[0, A, B, C]`. - - Args: - x: The local tensor to be permuted. - source_target_pairs: 2d int lists with shape [num_pairs, 2]. - source_target_pairs[i][0] represents the source replica id and - source_target_pairs[i][1] represents the target replica id. - name: Optional op name. - - Returns: - A `Tensor` which is permuted. - """ - return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) - - @ops.RegisterGradient("CollectivePermute") - def _collective_permute_grad(op, grad): - # The gradient of a collective permute operation is also a collective - # permute, but with source/target pairs reversed. The gradient with respect - # to input argument `source_target_pairs` is `None`. - source_target_pairs = op.inputs[1][:, ::-1] - return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None] - - @ops.RegisterGradient("CrossReplicaSum") - def _cross_replica_sum_grad(op, grad): - # The gradient of a cross replica sum is also a cross-replica sum. - # The gradient with respect to group_assignment is None. - return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None] - - # This extra type checking exists to give a more helpful error message in - # the common case that uint8 and int64 values are infed. Remove when both - # types are supported. - - _SUPPORTED_INFEED_DTYPES = set([ - dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, - dtypes.complex64, dtypes.uint32 - ]) - - def infeed_dequeue(dtype, shape, name=None): - """A placeholder op for a value that will be fed into the computation. - - Args: - dtype: A `tf.DType`. The type of elements in the tensor. - shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor. - name: A name for the operation (optional). - - Returns: - A `Tensor` of type `dtype`. - A tensor that will be provided using the infeed mechanism. - - Raises: - TypeError: If 'dtype` is not a supported infeed type. - """ - if dtype not in _SUPPORTED_INFEED_DTYPES: - raise TypeError( - "{} is not a supported TPU infeed type. Supported types are: " - "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) - - return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name) - - # pylint: disable=redefined-outer-name - def infeed_dequeue_tuple(dtypes, shapes, name=None): - """A placeholder op for values fed into the TPU simultaneously as a tuple. - - Args: - dtypes: A list of `tf.DType`s that has length `>= 1`. - The element types of each element in `outputs`. - shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). - The shapes of each tensor in `outputs`. - name: A name for the operation (optional). - - Returns: - A list of `Tensor` objects of type `dtypes`. - A list of tensors that will be provided using the infeed mechanism. - - Raises: - TypeError: If a type in 'dtypes` is not a supported infeed type. - """ - for dtype in dtypes: - if dtype not in _SUPPORTED_INFEED_DTYPES: - raise TypeError( - "{} is not a supported TPU infeed type. Supported types are: " - "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) - return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name) - # pylint: enable=redefined-outer-name - - # pylint: disable=protected-access - def send_tpu_embedding_gradients(inputs, - config, - learning_rates=None, - name=None): - """A placeholder op for feeding per-sample gradients to the embedding layer. - - Args: - inputs: A TensorList of gradients with which to update embedding tables. - This argument has the same length and shapes as the return value of - RecvTPUEmbeddingActivations, but contains gradients of the model's - loss with respect to the embedding activations. The embedding tables - are updated from these gradients via the optimizers specified in the - TPU embedding configuration given to tpu.initialize_system. - config: Serialized TPUEmbeddingConfiguration proto. - learning_rates: A TensorList of float32 scalars, one for each dynamic - learning rate tag: see the comments in - //third_party/tensorflow/contrib/tpu/proto/ - optimization_parameters.proto. - Multiple tables can share the same dynamic learning rate tag as - specified in the configuration. If the learning rates for all tables - are constant, this list should be empty. - name: A name for the operation (optional). - - Returns: - A SendTPUEmbeddingGradients operation. - """ - if learning_rates is None: - learning_rates = [] - return gen_tpu_ops._send_tpu_embedding_gradients( - inputs=inputs, learning_rates=learning_rates, config=config, name=name) - - - send_tpu_embedding_gradients.__doc__ = ( - gen_tpu_ops._send_tpu_embedding_gradients.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_integer_batch(batch, - device_ordinal, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - batch: A list of 1D tensors, one for each embedding table, containing the - indices into the tables. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingIntegerBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_integer_batch( - batch=batch, - device_ordinal=device_ordinal, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_integer_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_sparse_batch(sample_indices, - embedding_indices, - aggregation_weights, - device_ordinal, - combiners=None, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - sample_indices: A list of rank 1 Tensors specifying the training example - and feature to which the corresponding embedding_indices and - aggregation_weights values belong. sample_indices[i] must equal b * nf + - f, where nf is the number of features from the corresponding table, f is - in [0, nf), and b is in [0, batch size). - embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. - aggregation_weights: A list of rank 1 Tensors containing per sample -- - i.e. per (training example, feature) -- aggregation weights. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - combiners: A list of string scalars, one for each embedding table that - specify how to normalize the embedding activations after weighted - summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is - invalid to have the sum of the weights be 0 for 'mean' or the sum of the - squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default - is to use 'sum' for all tables (optional). - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingSparseBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch( - sample_indices=sample_indices, - embedding_indices=embedding_indices, - aggregation_weights=aggregation_weights, - device_ordinal=device_ordinal, - combiners=combiners, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_sparse_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__) - - # pylint: disable=protected-access - def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, - embedding_indices, - aggregation_weights, - table_ids, - device_ordinal, - combiners=None, - mode_override=None, - name=None): - """A placeholder op for enqueueing embedding IDs to the TPU. - - Args: - sample_indices: A list of rank 1 Tensors specifying the training example - to which the corresponding embedding_indices and aggregation_weights - values belong. It corresponds to sp_ids.indices[:,0] in - embedding_lookup_sparse(). - embedding_indices: A list of rank 1 Tensors, indices into the embedding - tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). - aggregation_weights: A list of rank 1 Tensors containing per training - example aggregation weights. It corresponds to sp_weights.values in - embedding_lookup_sparse(). - table_ids: A list of integers specifying the identifier of the embedding - table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to - lookup the corresponding input. The ith input is looked up using - table_ids[i]. The size of the table_ids list must be equal to that of - sample_indices, embedding_indices and aggregation_weights. - device_ordinal: The TPU device to use. Should be >= 0 and less than the - number of TPU cores in the task on which the node is placed. - combiners: A list of string scalars, one for each embedding table that - specify how to normalize the embedding activations after weighted - summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is - invalid to have the sum of the weights be 0 for 'mean' or the sum of the - squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default - is to use 'sum' for all tables (optional). - mode_override: A string input that overrides the mode specified in the - TPUEmbeddingConfiguration. Supported values are {'unspecified', - 'inference', 'training', 'backward_pass_only'}. When set to - 'unspecified', the mode set in TPUEmbeddingConfiguration is used, - otherwise mode_override is used (optional). - name: A name for the operation (optional). - - Returns: - An EnqueueTPUEmbeddingSparseTensorBatch operation. - """ - if mode_override is None: - mode_override = "unspecified" - return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch( - sample_indices=sample_indices, - embedding_indices=embedding_indices, - aggregation_weights=aggregation_weights, - table_ids=table_ids, - device_ordinal=device_ordinal, - combiners=combiners, - mode_override=mode_override, - name=name) - - enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( - gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__) - -else: - # We have already built the appropriate libraries into the binary via CMake - # if we have built contrib, so we don't need this - pass +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.ops.tpu_ops import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py index 5ca38cd1bae5753a7398834bd96d3b26e66b4941..788e1fe0568cf2f406c379e4d928100ea51a37a3 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py @@ -1,38 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Operations to select TPU core to run.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_ordinal_selector_op = loader.load_op_library( - resource_loader.get_path_to_datafile("_tpu_ordinal_selector_op.so")) - -else: - # We have already built the appropriate libraries into the binary via CMake - # if we have built contrib, so we don't need this - pass +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.ops.tpu_ordinal_selector_op import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/profiler/__init__.py b/tensorflow/contrib/tpu/python/profiler/__init__.py index 15ce6aceec299adacd7025f0021cf8b6f6ef765b..aeb061dbe114bc287946b50d08a86778c78c7b38 100644 --- a/tensorflow/contrib/tpu/python/profiler/__init__.py +++ b/tensorflow/contrib/tpu/python/profiler/__init__.py @@ -1,31 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Classes for TPU trace events.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.tpu.profiler.tpu_profiler_analysis_pb2 import * -from tensorflow.contrib.tpu.profiler.trace_events_pb2 import * +from tensorflow.python.tpu.profiler import * # pylint: enable=wildcard-import,unused-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent'] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/tpu/python/tpu/__init__.py b/tensorflow/contrib/tpu/python/tpu/__init__.py index 0dffd7064b19f353aed6afa3ad383564643a4a90..82d4f68c0221013706f70bcf54ae4c97cc7db1d3 100644 --- a/tensorflow/contrib/tpu/python/tpu/__init__.py +++ b/tensorflow/contrib/tpu/python/tpu/__init__.py @@ -1,20 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Ops related to Tensor Processing Units.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py index 6ce96e5bcdbe5777f68eb969be46423b5b3410cb..41aa4d267812cabe775459723df7e01efaa83c93 100644 --- a/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/_tpu_estimator_embedding.py @@ -1,273 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""Tooling for support TPU embedding in TPUEstimator.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - -from tensorflow.contrib.tpu.python.tpu import feature_column as tpu_fc -from tensorflow.contrib.tpu.python.tpu import tpu_embedding -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.feature_column import feature_column as core_fc -from tensorflow.python.feature_column import feature_column_lib as core_fc_lib - -# pylint: disable=protected-access -_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn, - tpu_fc._TPUSharedEmbeddingColumn) -_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn, - core_fc_lib.EmbeddingColumn, - core_fc._SharedEmbeddingColumn) -_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn) - -# pylint: enable=protected-access - - -def get_tpu_embedding_config_from_feature_columns(feature_columns): - """Create configs for TPUEmbedding from a list of feature columns. - - This function will place one embedding tensor per table and the return is - intended to be used as input to TPUEmbedding. - - Args: - feature_columns: a list of supported feature columns. - - Returns: - A pair of dicts, the first maps tables to their config, the second maps - features to tables. - """ - - allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access - - for column in feature_columns: - if not isinstance(column, allowed): - raise TypeError( - 'Unsupported feature column {}. Supported types are {}.'.format( - type(column), allowed)) - - table_to_config = {} - feature_to_table = {} - for column in feature_columns: - feature_name = column.get_feature_key_name() - table_name = 'tbl_{}'.format(column.get_embedding_var_name()) - if feature_name in feature_to_table: - raise ValueError( - 'Feature column {} is used with multiple embeddings and this is ' - 'not supported.'.format(feature_name)) - feature_to_table[feature_name] = table_name - vocabulary_size, dimension = column.get_embedding_table_size() - table_to_config[table_name] = tpu_embedding.TableConfig( - vocabulary_size=vocabulary_size, - dimension=dimension, - initializer=column.get_initializer(), - combiner=column.get_combiner()) - - return table_to_config, feature_to_table - - -def _get_tpu_embedding_optimization_parameters(embedding_config_spec): - """Get tpu_embedding._OptimizationParameters from EmbeddingConfigSpec.""" - if embedding_config_spec.optimizer_type == 'adagrad': - return tpu_embedding.AdagradParameters( - embedding_config_spec.learning_rate, - embedding_config_spec.adagrad_initial_accumulator, - embedding_config_spec.use_gradient_accumulation) - elif embedding_config_spec.optimizer_type == 'sgd': - return tpu_embedding.StochasticGradientDescentParameters( - embedding_config_spec.learning_rate, - embedding_config_spec.use_gradient_accumulattion) - elif embedding_config_spec.optimizer_type == 'adam': - return tpu_embedding.AdamParameters( - embedding_config_spec.learning_rate, - embedding_config_spec.adam_parameters.beta1, - embedding_config_spec.adam_parameters.beta2, - embedding_config_spec.adam_parameters.epsilon, - use_gradient_accumulation=embedding_config_spec - .use_gradient_accumulation) - else: - raise ValueError('optimizer_type must be adagrad or sgd or adam for now.') - - -AdamParameters = collections.namedtuple('AdamParameters', - ['beta1', 'beta2', 'epsilon']) - - -# TODO(shizhiw): Improve the API to support more optimizer parameters in API. -class EmbeddingConfigSpec( - collections.namedtuple('EmbeddingConfigSpec', [ - 'feature_columns', 'learning_rate', 'optimizer_type', - 'adagrad_initial_accumulator', 'clipping_limit', - 'use_gradient_accumulation', 'adam_parameters' - ])): - """Class to keep track of embedding config specification.""" - - def __new__(cls, - feature_columns, - learning_rate, - optimizer_type='adagrad', - adagrad_initial_accumulator=None, - clipping_limit=None, - use_gradient_accumulation=False, - adam_parameters=None): - """Creates an EmbeddingConfigSpec instance. - - Args: - feature_columns: All `FeatureColumn`s used by model. - learning_rate: embedding optimizer learning rate. - optimizer_type: (String) Name of the optimizer for embedding gradients - updates. Must be either 'adagrad' ( `tf.train.AdagradOptimizer`, default - value), 'sgd' (`tf.train.GradientDescentOptimizer`), or 'adam' - (`tf.contrib.opt.LazyAdamOptimizer`) for lazy Adam. This optimizer will - be applied to all embedding variables specified by `feature_columns`. - adagrad_initial_accumulator: Initial accumulator for Adagrad. Used when - optimizer_type is 'adagrad'. Default is `0.1`. - clipping_limit: (Optional) Clipping limit (absolute value). - use_gradient_accumulation: (Experimental) Whether to accumulate the - gradients across TPU embedding mini-batches. Gradient accumulation does - not affect SGD and therefore this is applicable only for Adagrad. - adam_parameters: AdamParameters. Used when optimizer_type is 'adam'. - Default is 0.9 for beta1, 0.999 for beta2 and 1e-8 for epsilon. - - Returns: - An EmbeddingConfigSpec instance. - - Raises: - ValueError: If the feature_columns are not specified. - TypeError: If the feature columns are not of ths correct type (one of - _SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES OR - _EMBEDDING_COLUMN_CLASSES). - ValueError: If use_gradient_accumulation is True for SGD. - ValueError: If `optimizer_type` is not one of "adagrad" or "sgd" or - "adam". - """ - if not feature_columns: - raise ValueError('`feature_columns` cannot be `None` or empty.') - - # It is unknown at this moment, whether the TPUEstimator is running in CPU - # or TPU mode. So allow non-TPU embedding columns also. - supported_classes = tuple( - list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) + - list(_EMBEDDING_COLUMN_CLASSES)) - - for column in feature_columns: - if not isinstance(column, supported_classes): - raise TypeError( - 'All feature columns must be supported types in {}. Got {}'.format( - supported_classes, type(column))) - - if optimizer_type == 'adagrad': - if adagrad_initial_accumulator is None: - adagrad_initial_accumulator = 0.1 - if adagrad_initial_accumulator <= 0: - raise ValueError('Adagrad initial_accumulator must be positive') - elif optimizer_type == 'sgd': - if use_gradient_accumulation: - raise ValueError('Gradient accumulation makes sense for Adagrad only.') - elif optimizer_type == 'adam': - if adam_parameters is None: - adam_parameters = AdamParameters(0.9, 0.999, 1e-8) - if adam_parameters.beta1 < 0. or adam_parameters.beta1 >= 1.: - raise ValueError('beta1 must be between 0. and 1; got {}.'.format( - adam_parameters.beta1)) - if adam_parameters.beta2 < 0. or adam_parameters.beta2 >= 1.: - raise ValueError('beta2 must be between 0. and 1; got {}.'.format( - adam_parameters.beta2)) - if adam_parameters.epsilon <= 0.: - raise ValueError('epsilon must be positive; got {}.'.format( - adam_parameters.epsilon)) - else: - raise ValueError('optimizer_type must be adagrad or sgd or adam for now.') - - return super(EmbeddingConfigSpec, cls).__new__( - cls, - feature_columns=feature_columns, - learning_rate=learning_rate, - optimizer_type=optimizer_type, - adagrad_initial_accumulator=adagrad_initial_accumulator, - clipping_limit=clipping_limit, - use_gradient_accumulation=use_gradient_accumulation, - adam_parameters=adam_parameters) - - -class EmbeddingConfig(object): - """This is the internal immutable object for embedding config. - - `_EmbeddingConfig` is responsible to _translate_ user provided - `EmbeddingConfigSpec` to internal data structures, mostly constructor - arguments of `TPUEmbedding`. - """ - - def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size, - num_hosts, num_cores, master): - self._embedding_config_spec = embedding_config_spec - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size - self._num_hosts = num_hosts - self._num_cores = num_cores - self._master = master - - self._table_to_config_dict, self._feature_to_table_dict = ( - get_tpu_embedding_config_from_feature_columns( - embedding_config_spec.feature_columns)) - self._optimization_parameters = _get_tpu_embedding_optimization_parameters( - self._embedding_config_spec) - self._mode_to_tpu_embedding_dict = {} - - def has_embedding_tables(self): - return bool(self._table_to_config_dict) - - def _create_tpu_embedding(self, mode): - """Create tpu_embedding.TPUEmbedding based on mode.""" - if mode == model_fn_lib.ModeKeys.TRAIN: - batch_size = self._train_batch_size - else: - batch_size = self._eval_batch_size - - if mode == model_fn_lib.ModeKeys.TRAIN: - tpu_embedding_mode = tpu_embedding.TRAINING - elif (mode == model_fn_lib.ModeKeys.EVAL or - mode == model_fn_lib.ModeKeys.PREDICT): - tpu_embedding_mode = tpu_embedding.INFERENCE - else: - raise ValueError('Mode {} is not supported.'.format(mode)) - - tpu_embedding_ = tpu_embedding.TPUEmbedding( - self._table_to_config_dict, - self._feature_to_table_dict, - batch_size, - tpu_embedding_mode, - self._master, - self._optimization_parameters, - ) - return tpu_embedding_ - - def get_tpu_embedding(self, mode): - if mode not in self._mode_to_tpu_embedding_dict: - self._mode_to_tpu_embedding_dict[mode] = ( - self._create_tpu_embedding(mode)) - return self._mode_to_tpu_embedding_dict[mode] - - -def split_inputs(ctx, features, labels): - """Splits the dense and sparse tensors inside the features and labels.""" - sparse_features = collections.OrderedDict() - if ctx.embedding_config: - tpu_embedding_ = ctx.embedding_config.tpu_embedding - for feature_key in tpu_embedding_.feature_to_table_dict: - sparse_features[feature_key] = features.pop(feature_key) - - return features, labels, sparse_features +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu._tpu_estimator_embedding import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index 1b09ce173a64ba3f93ec019c8fd65dc4710f0fcf..5eb8034e47474873ccef0b6123f2becd0668738c 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -1,212 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the 'License'); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Hook for asynchronous checkpointing. - -This hook dispatches checkpoint writing operations in a separate thread to -allow execution to continue on the main thread. -""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import threading -import time - -from tensorflow.core.util.event_pb2 import SessionLog -from tensorflow.python.framework import meta_graph -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import training_util -from tensorflow.python.training.session_run_hook import SessionRunArgs -from tensorflow.python.training.summary_io import SummaryWriterCache - - -class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): - """Saves checkpoints every N steps or seconds.""" - - def __init__(self, - checkpoint_dir, - save_secs=None, - save_steps=None, - saver=None, - checkpoint_basename="model.ckpt", - scaffold=None, - listeners=None): - """Initializes a `CheckpointSaverHook`. - - Args: - checkpoint_dir: `str`, base directory for the checkpoint files. - save_secs: `int`, save every N secs. - save_steps: `int`, save every N steps. - saver: `Saver` object, used for saving. - checkpoint_basename: `str`, base name for the checkpoint files. - scaffold: `Scaffold`, use to get saver object. - listeners: List of `CheckpointSaverListener` subclass instances. Used for - callbacks that run immediately before or after this hook saves the - checkpoint. - - Raises: - ValueError: One of `save_steps` or `save_secs` should be set. - ValueError: At most one of `saver` or `scaffold` should be set. - """ - logging.info("Create AsyncCheckpointSaverHook.") - if saver is not None and scaffold is not None: - raise ValueError("You cannot provide both saver and scaffold.") - self._saver = saver - self._save_thread = None - self._write_graph_thread = None - self._checkpoint_dir = checkpoint_dir - self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) - self._scaffold = scaffold - self._timer = basic_session_run_hooks.SecondOrStepTimer( - every_secs=save_secs, every_steps=save_steps) - self._listeners = listeners or [] - self._steps_per_run = 1 - self._summary_writer = None - self._global_step_tensor = None - - self._last_checkpoint_step = None - - def _set_steps_per_run(self, steps_per_run): - self._steps_per_run = steps_per_run - - def begin(self): - self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) - self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access - if self._global_step_tensor is None: - raise RuntimeError( - "Global step should be created to use CheckpointSaverHook.") - for l in self._listeners: - l.begin() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - - # We do write graph and saver_def at the first call of before_run. - # We cannot do this in begin, since we let other hooks to change graph and - # add variables in begin. Graph is finalized after all begin calls. - def _write_graph_fn(self): - training_util.write_graph( - ops.get_default_graph().as_graph_def(add_shapes=True), - self._checkpoint_dir, "graph.pbtxt") - self._write_graph_thread = threading.Thread(target=_write_graph_fn, - args=[self]) - self._write_graph_thread.start() - - saver_def = self._get_saver().saver_def if self._get_saver() else None - graph = ops.get_default_graph() - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) - self._summary_writer.add_graph(graph) - self._summary_writer.add_meta_graph(meta_graph_def) - # The checkpoint saved here is the state at step "global_step". - self._save(session, global_step) - self._timer.update_last_triggered_step(global_step) - - def before_run(self, run_context): # pylint: disable=unused-argument - return SessionRunArgs(self._global_step_tensor) - - def after_run(self, run_context, run_values): - global_step = run_context.session.run(self._global_step_tensor) - if self._timer.should_trigger_for_step(global_step): - self._timer.update_last_triggered_step(global_step) - logging.info("Triggering checkpoint. %s", global_step) - if self._save(run_context.session, global_step): - run_context.request_stop() - - def end(self, session): - if self._save_thread: - logging.info("Waiting for any pending checkpoints to finish.") - self._save_thread.join() - if self._write_graph_thread: - logging.info("Waiting for any pending write_graph to finish.") - self._write_graph_thread.join() - - last_step = session.run(self._global_step_tensor) - - if self._last_checkpoint_step != last_step: - self._save(session, last_step, asynchronous=False) - - for l in self._listeners: - l.end(session, last_step) - - def _save(self, session, step, asynchronous=True): - """Saves the latest checkpoint, returns should_stop.""" - - # Skip saving on step 0 - if step == 0: - return - - def _save_fn(): - """Run the saver process.""" - logging.info("Saving checkpoints for %d into %s.", step, self._save_path) - - start_time = time.time() - for l in self._listeners: - l.before_save(session, step) - - self._get_saver().save(session, self._save_path, global_step=step) - self._summary_writer.add_session_log( - SessionLog( - status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), - step) - - for l in self._listeners: - l.after_save(session, step) - - end_time = time.time() - logging.info("Checkpoint actual writing time: (%.3f sec)", - end_time - start_time) - logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - - if not asynchronous: - self._last_checkpoint_step = step - _save_fn() - return - - if self._save_thread is not None: - self._save_thread.join(timeout=0.1) - if self._save_thread.is_alive(): - logging.info("Saver thread still in progress, skipping checkpoint.") - return - - self._last_checkpoint_step = step - self._save_thread = threading.Thread(target=_save_fn) - self._save_thread.start() - - def _get_saver(self): - if self._saver is not None: - return self._saver - elif self._scaffold is not None: - return self._scaffold.saver - - # Get saver from the SAVERS collection if present. - collection_key = ops.GraphKeys.SAVERS - savers = ops.get_collection(collection_key) - if not savers: - raise RuntimeError( - "No items in collection {}. Please add a saver to the collection " - "or provide a saver or scaffold.".format(collection_key)) - elif len(savers) > 1: - raise RuntimeError( - "More than one item in collection {}. " - "Please indicate which one to use by passing it to the constructor." - .format(collection_key)) - - self._saver = savers[0] - return savers[0] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.async_checkpoint import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/bfloat16.py b/tensorflow/contrib/tpu/python/tpu/bfloat16.py index fa74f651aa63c72d14eb78c8af479263810e9b7d..f3d392a8daec2a80f974d90051324a02be002afd 100644 --- a/tensorflow/contrib/tpu/python/tpu/bfloat16.py +++ b/tensorflow/contrib/tpu/python/tpu/bfloat16.py @@ -1,77 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Helper context for running models with bfloat16.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import tf_contextlib - - -def _get_custom_getter(): - """Returns a custom getter that this class's methods must be called under. - - All methods of this class must be called under a variable scope that was - passed this custom getter. Example: - - ```python - network = ConvNetBuilder(...) - with tf.variable_scope('cg', custom_getter=network.get_custom_getter()): - network.conv(...) - # Call more methods of network here - ``` - - Currently, this custom getter only does anything if self.use_tf_layers is - True. In that case, it causes variables to be stored as dtype - self.variable_type, then casted to the requested dtype, instead of directly - storing the variable as the requested dtype. - """ - - def inner_custom_getter(getter, *args, **kwargs): - """Custom getter that forces variables to have type self.variable_type.""" - cast_to_bfloat16 = False - requested_dtype = kwargs['dtype'] - if requested_dtype == dtypes.bfloat16: - # Only change the variable dtype if doing so does not decrease variable - # precision. - kwargs['dtype'] = dtypes.float32 - cast_to_bfloat16 = True - var = getter(*args, **kwargs) - # This if statement is needed to guard the cast, because batch norm - # assigns directly to the return value of this custom getter. The cast - # makes the return value not a variable so it cannot be assigned. Batch - # norm variables are always in fp32 so this if statement is never - # triggered for them. - if cast_to_bfloat16: - var = math_ops.cast(var, dtypes.bfloat16) - return var - - return inner_custom_getter - - -@tf_contextlib.contextmanager -def bfloat16_scope(): - """Scope class for bfloat16 variables so that the model uses custom getter. - - This enables variables to be read as bfloat16 type when using get_variable. - """ - with variable_scope.variable_scope( - '', custom_getter=_get_custom_getter()) as varscope: - yield varscope +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.bfloat16 import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index bc0cd41d210ac6f8de1b20ebf744ee1e1dd04137..c20aac7e36aa31c5a9d88ca6fe02a8703f9ed5a3 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -1,191 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Library of Cloud TPU helper functions for data loading.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.experimental.ops import batching -from tensorflow.python.data.experimental.ops import interleave_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.ops import readers -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import functional_ops - - -def _TextLineDataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = readers.TextLineDataset(filename, buffer_size=buffer_size) - return dataset - - -def _TFRecordDataset(filename): - buffer_size = 8 * 1024 * 1024 # 8 MiB per file - dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size) - return dataset - - -_FILETYPE_MAP = { - 'tfrecord': _TFRecordDataset, - 'textline': _TextLineDataset, - 'text': _TextLineDataset, -} - - -def StreamingFilesDataset(files, - filetype=None, - file_reader_job=None, - worker_job=None, - num_epochs=None, - filename_shuffle_buffer_size=None, - num_parallel_reads=None, - batch_transfer_size=None, - sloppy=None): - """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM). - - Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read - files local to your GCE VM. In order to train using files stored on your local - VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset - helper to generate a dataset to feed your Cloud TPU with files from your GCE - VM. - - The resulting dataset may return an OutOfRangeError if there are no files - found as a result of the fileglob expansion. - - Note: StreamingFilesDataset assumes that the session is using a - TPUClusterResolver and has therefore a worker and a coordinator job. File - loading will be done on the coordinator job. - - Args: - files: A string glob to match files, or a `tf.data.Dataset` generating file - names. - filetype: A string (one of 'tfrecord', or 'textline') or a single-argument - TensorFlow function that when given a filename returns a dataset. - file_reader_job: An optional string that corresponds to the job that should - perform the file reads. - worker_job: An optional string that corresponds to the job that should - process the tensors (i.e. your GPU or TPU worker). - num_epochs: The number of epochs through the training set that should be - generated. By default, it will repeat infinitely. - filename_shuffle_buffer_size: An optional integer whose value controls the - shuffling of the file names. If you would like to read from the files in - the same order, set to 0 or False. - num_parallel_reads: An optional integer controlling the number of files to - read from concurrently. (Set to 1 for no parallelism.) - batch_transfer_size: An optional integer controlling the batching used to - amortize the remote function invocation overhead. Set to a very large - number to increase throughput. Set to a very small number to reduce memory - consumption. Set to False to skip batching. - sloppy: (Optional.) If `False`, read input data while maintaining a - deterministic order. (This may have significant performance impacts.) - sloppy defaults to: True. - Returns: - A `tf.data.Dataset` with an infinite stream of elements generated by a - parallel interleaving of the set of files matched (or generated) by `files` - with a type is the output of the dataset specified by `filetype`. - - Raises: - ValueError: if any argument is not of the expected type. - """ - if filetype is None: - filetype = 'tfrecord' - - if isinstance(filetype, str): - if filetype not in _FILETYPE_MAP: - raise ValueError('Unexpected filetype: %s' % filetype) - reader_fn = _FILETYPE_MAP[filetype] - elif callable(filetype): - reader_fn = filetype - else: - raise ValueError('filetype should be a string or a callable') - - file_reader_job = file_reader_job or 'coordinator' - - worker_job = worker_job or 'worker' - - if filename_shuffle_buffer_size is None: - filename_shuffle_buffer_size = 4096 - - num_parallel_reads = num_parallel_reads or 8 - - if batch_transfer_size is None: - batch_transfer_size = 256 - - if sloppy is None: - sloppy = True - - with ops.device('/job:%s' % file_reader_job): - if isinstance(files, str): - source_dataset = dataset_ops.Dataset.list_files(files) - elif isinstance(files, dataset_ops.DatasetV2): - source_dataset = files - else: - raise ValueError('files was not a string or a dataset: %s' % files) - - if filename_shuffle_buffer_size: - source_dataset = source_dataset.shuffle( - buffer_size=filename_shuffle_buffer_size) - - source_dataset = source_dataset.apply( - interleave_ops.parallel_interleave( - reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) - - source_dataset = source_dataset.repeat(num_epochs) - - if batch_transfer_size: - source_dataset = source_dataset.batch(batch_transfer_size) - - source_dataset = source_dataset.prefetch(1) - - source_iterator = dataset_ops.make_one_shot_iterator(source_dataset) - source_handle = source_iterator.string_handle() - - @function.Defun(dtypes.string) - def LoadingFunc(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, source_dataset.output_types, source_dataset.output_shapes) - return remote_iterator.get_next() - - def MapFn(unused_input): - if isinstance(source_dataset.output_types, dtypes.DType): - output_types = [source_dataset.output_types] - elif isinstance(source_dataset.output_types, (list, tuple)): - output_types = source_dataset.output_types - else: - raise ValueError('source dataset has invalid output types') - remote_calls = functional_ops.remote_call( - args=[source_handle], - Tout=output_types, - f=LoadingFunc, - target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) - if len(remote_calls) == 1: - return remote_calls[0] - else: - return remote_calls - - with ops.device('/job:%s' % worker_job): - output_dataset = dataset_ops.Dataset.range(2).repeat().map( - MapFn, num_parallel_calls=4 if sloppy else None) - output_dataset = output_dataset.prefetch(1) - - if batch_transfer_size: - # Undo the batching used during the transfer. - output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1) - - return output_dataset +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.datasets import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 3313dc749c2c7606101b2dc96614df2d052dfed1..05dffef3a1efdae2ad7306ca5ad3bc7a9eac04cf 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -1,313 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Library of TPU helper functions.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.python.tpu.topology import Topology - - -SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]] - - -def _compute_task_and_cores_to_replicas(core_assignment, topology): - """Computes a nested dict which maps task and logical core to replicas.""" - task_and_cores_to_replicas = {} - for replica in xrange(core_assignment.shape[0]): - for logical_core in xrange(core_assignment.shape[1]): - coordinates = core_assignment[replica, logical_core, :] - task_id = topology.task_ordinal_at_coordinates(coordinates) - if task_id not in task_and_cores_to_replicas: - task_and_cores_to_replicas[task_id] = {} - if logical_core not in task_and_cores_to_replicas[task_id]: - task_and_cores_to_replicas[task_id][logical_core] = set() - - task_and_cores_to_replicas[task_id][logical_core].add(replica) - - task_to_sorted_replica_id = {} - - for task, core_to_replicas in task_and_cores_to_replicas.items(): - core_to_sorted_replicas = {} - for core, replicas in core_to_replicas.items(): - core_to_sorted_replicas[core] = sorted(replicas) - - task_to_sorted_replica_id[task] = core_to_sorted_replicas - return task_to_sorted_replica_id - - -class DeviceAssignment(object): - """Mapping from logical cores in a computation to the physical TPU topology. - - Prefer to use the `device_assignment()` helper to construct a - `DeviceAssignment`; it is easier if less flexible than constructing a - `DeviceAssignment` directly. - """ - - def __init__(self, topology, core_assignment): - """Constructs a `DeviceAssignment` object. - - Args: - topology: A `Topology` object that describes the physical TPU topology. - core_assignment: A logical to physical core mapping, represented as a - rank 3 numpy array. See the description of the `core_assignment` - property for more details. - - Raises: - ValueError: If `topology` is not `Topology` object. - ValueError: If `core_assignment` is not a rank 3 numpy array. - """ - if not isinstance(topology, Topology): - raise ValueError("topology must be a Topology object, got {}".format( - type(topology))) - core_assignment = np.asarray(core_assignment, dtype=np.int32) - - self._topology = topology - - if core_assignment.ndim != 3: - raise ValueError("core_assignment must be a rank 3 numpy array, " - "got shape {}".format(core_assignment.shape)) - - self._num_replicas = core_assignment.shape[0] - self._num_cores_per_replica = core_assignment.shape[1] - - if core_assignment.shape[-1] != topology.mesh_rank: - raise ValueError( - "minor dimension of core_assignment must have size equal to topology " - "rank ({}), got shape {}".format(topology.mesh_rank, - core_assignment.shape)) - - self._core_assignment = core_assignment - self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas( - self._core_assignment, topology) - - @property - def topology(self): - """A `Topology` that describes the TPU topology.""" - return self._topology - - @property - def num_cores_per_replica(self): - """The number of cores per replica.""" - return self._num_cores_per_replica - - @property - def num_replicas(self): - """The number of replicas of the computation.""" - return self._num_replicas - - @property - def core_assignment(self): - """The logical to physical core mapping. - - Returns: - An integer numpy array of rank 3, with shape - `[num_replicas, num_cores_per_replica, topology_rank]`. Maps - (replica, logical core) pairs to physical topology coordinates. - """ - return self._core_assignment - - def _coordinates(self, replica, logical_core): - """Returns the physical topology coordinates of a logical core.""" - return tuple(self.core_assignment[replica, logical_core, :]) - - def lookup_replicas(self, task_id, logical_core): - """Lookup replica ids by task number and logical core. - - Args: - task_id: TensorFlow task number. - logical_core: An integer, identifying a logical core. - Returns: - A sorted list of the replicas that are attached to that task and - logical_core. - Raises: - ValueError: If no replica exists in the task which contains the logical - core. - """ - try: - return self._task_and_cores_to_replicas[task_id][logical_core] - except KeyError: - raise ValueError( - "Can not find any replica in task: {} contains logical_core: {} ". - format(task_id, logical_core)) - - def tpu_ordinal(self, replica=0, logical_core=0): - """Returns the ordinal of the TPU device assigned to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.tpu_device_ordinal_at_coordinates(coordinates) - - def host_device(self, replica=0, logical_core=0, job=None): - """Returns the CPU device attached to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.cpu_device_name_at_coordinates(coordinates, job=job) - - def tpu_device(self, replica=0, logical_core=0, job=None): - """Returns the name of the TPU device assigned to a logical core.""" - coordinates = self._coordinates(replica, logical_core) - return self._topology.tpu_device_name_at_coordinates(coordinates, job=job) - - -def device_assignment(topology, - computation_shape=None, - computation_stride=None, - num_replicas=1): - """Computes a device_assignment of a computation across a TPU topology. - - Attempts to choose a compact grid of cores for locality. - - Returns a `DeviceAssignment` that describes the cores in the topology assigned - to each core of each replica. - - `computation_shape` and `computation_stride` values should be powers of 2 for - optimal packing. - - Args: - topology: A `Topology` object that describes the TPU cluster topology. - To obtain a TPU topology, evaluate the `Tensor` returned by - `initialize_system` using `Session.run`. Either a serialized - `TopologyProto` or a `Topology` object may be passed. Note: you must - evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here. - computation_shape: A rank 1 int32 numpy array with size equal to the - topology rank, describing the shape of the computation's block of cores. - If None, the `computation_shape` is `[1] * topology_rank`. - computation_stride: A rank 1 int32 numpy array of size `topology_rank`, - describing the inter-core spacing of the `computation_shape` cores in the - TPU topology. If None, the `computation_stride` is `[1] * topology_rank`. - num_replicas: The number of computation replicas to run. The replicas will - be packed into the free spaces of the topology. - - Returns: - A DeviceAssignment object, which describes the mapping between the logical - cores in each computation replica and the physical cores in the TPU - topology. - - Raises: - ValueError: If `topology` is not a valid `Topology` object. - ValueError: If `computation_shape` or `computation_stride` are not 1D int32 - numpy arrays with shape [3] where all values are positive. - ValueError: If computation's replicas cannot fit into the TPU topology. - """ - # Deserialize the Topology proto, if it is a string. - if isinstance(topology, bytes): - topology = Topology(serialized=topology) - - if not isinstance(topology, Topology): - raise ValueError("`topology` is not a Topology object; got {}".format( - type(topology))) - - topology_rank = len(topology.mesh_shape) - mesh_shape = topology.mesh_shape - if computation_shape is None: - computation_shape = np.array([1] * topology_rank, dtype=np.int32) - else: - computation_shape = np.asarray(computation_shape, dtype=np.int32) - - if computation_stride is None: - computation_stride = np.array([1] * topology_rank, dtype=np.int32) - else: - computation_stride = np.asarray(computation_stride, dtype=np.int32) - - if computation_shape.shape != (topology_rank,): - raise ValueError("computation_shape must have shape [{}]; got {}".format( - topology_rank, computation_shape.shape)) - if computation_stride.shape != (topology_rank,): - raise ValueError("computation_stride must have shape [{}]; got {}".format( - topology_rank, computation_stride.shape)) - - if any(computation_shape < 1): - raise ValueError( - "computation_shape must be positive; got computation_shape={}".format( - computation_shape)) - if any(computation_stride < 1): - raise ValueError( - "computation_stride must be positive; got computation_stride={}".format( - computation_stride)) - - # Computes the physical size of one computation instance. - computation_footprint = computation_shape * computation_stride - if any(computation_footprint > mesh_shape): - raise ValueError( - "computation footprint {} does not fit in TPU topology shape {}".format( - computation_footprint, mesh_shape)) - - # Computes how many copies of the computation footprint fit in the mesh. - block_counts = mesh_shape // computation_footprint - - replica_counts = block_counts * computation_stride - max_replicas = np.prod(replica_counts) - if num_replicas > max_replicas: - raise ValueError( - "requested {} replicas but only {} replicas with shape {} and " - "computation_stride {} fit in a TPU mesh of shape {}".format( - num_replicas, max_replicas, computation_shape, computation_stride, - mesh_shape)) - - def ceil_of_ratio(n, m): - return (n + m - 1) // m - - replica_shape = [0] * topology_rank - if num_replicas > 0: - remaining_replicas = num_replicas - remaining_dims = topology_rank - - # Choose dimensions as close to an equal cube as possible, in order of - # increasing dimension size. By visiting dimensions in increasing size, we - # assign the most constrained dimension first, so we won't make infeasible - # choices. - # - # As a secondary sort order, visit the dimensions in reverse order. This - # means we try to use both cores on the same chip in preference to two cores - # on different chips. - for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))): - i = -ni - target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims))) - replica_shape[i] = min(target_size, x) - remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i]) - remaining_dims -= 1 - - assert remaining_replicas == 1 and remaining_dims == 0 - - # Assigns an offset to each replica such that no two replicas overlap. - replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32) - for replica in xrange(num_replicas): - # Chooses a replica number in each axis. - t = replica - pos = [] - for dim in replica_shape[::-1]: - pos.append(t % dim) - t //= dim - replica_pos = np.array(pos[::-1], dtype=np.int32) - - # Determines where that replica starts in each axis. - outer = replica_pos // computation_stride - inner = replica_pos % computation_stride - replica_offsets[replica, :] = outer * computation_footprint + inner - - # Computes a complete logical core -> physical core mapping for each replica. - indices = [ - np.arange(0, computation_shape[i] * computation_stride[i], - computation_stride[i]) for i in xrange(topology_rank) - ] - indices = np.concatenate( - [i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")], - axis=-1) - indices = indices.reshape((-1, topology_rank)) - assignment = indices + replica_offsets[:, np.newaxis, :] - return DeviceAssignment(topology, core_assignment=assignment) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.device_assignment import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py index 52e1ea42370d653d1de7c12eee4b456ec7ce921c..1b1328b4075d9a737e40693c13e33e0b7c1fbedf 100644 --- a/tensorflow/contrib/tpu/python/tpu/error_handling.py +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -1,132 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""ErrorRendezvous handler for collecting errors from multiple threads.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import sys -import threading -import time - -import six - -from tensorflow.python.framework import errors -from tensorflow.python.platform import tf_logging as logging - -_UNINTERESTING_ERRORS = (errors.CancelledError,) - - -class ErrorRendezvous(object): - """Resolve errors from multiple threads during TPU execution. - - TPU errors can occur on the infeed or outfeed threads as well as the main - training thread. - - Depending on which thread "wins" and receives the session error first, we may - end up showing users a confusing and non-actionable error message (session - cancelled) instead of a root cause (e.g. a bad filename). - - The rendezvous object provides a location to capture these errors until all - threads terminate. At that point we can choose the most informative error - to report. - """ - - def __init__(self, num_sources): - # string -> (message, traceback) - self._errors = {} - self._num_sources = num_sources - self._session_cancel_timer = None - - def record_error(self, source, exc_info, session=None): - """Report an exception from the given source. - - If a session is passed, a timer will be registered to close it after a few - seconds. This is necessary to ensure the main training loop does not hang - if an infeed/oufeed error occurs. We sleep a few seconds to allow a more - interesting error from another thread to propagate. - - Args: - source: string, source of the error - exc_info: Output from `sys.exc_info` (type, value, traceback) - session: Session to close after delay. - """ - _, value, _ = exc_info - self._errors[source] = exc_info - logging.info('Error recorded from %s: %s', source, value) - - if session is not None and self._session_cancel_timer is None: - - def _cancel_session(): - time.sleep(5) - try: - session.close() - except: # pylint: disable=bare-except - pass - - self._session_cancel_timer = threading.Thread(target=_cancel_session,) - self._session_cancel_timer.daemon = True - self._session_cancel_timer.start() - - def record_done(self, source): - """Mark execution source `source` as done. - - If an error was originally reported from `source` it is left intact. - - Args: - source: `str`, source being recorded - """ - logging.info('%s marked as finished', source) - if source not in self._errors: - self._errors[source] = None - - @contextlib.contextmanager - def catch_errors(self, source, session=None): - """Context manager to report any errors within a block.""" - try: - yield - except Exception: # pylint: disable=broad-except - self.record_error(source, sys.exc_info(), session) - - def raise_errors(self, timeout_sec=0): - """Wait for up to `timeout` seconds for all error sources to finish. - - Preferentially raise "interesting" errors (errors not in the - _UNINTERESTING_ERRORS) set. - - Args: - timeout_sec: Seconds to wait for other error sources. - """ - for _ in range(timeout_sec): - if len(self._errors) == self._num_sources: - break - time.sleep(1) - - kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] - - # First check for any interesting errors, then fall back on the session - # cancelled errors etc. - for k, (typ, value, traceback) in kept_errors: - if isinstance(value, _UNINTERESTING_ERRORS): - continue - else: - logging.warn('Reraising captured error') - six.reraise(typ, value, traceback) - - for k, (typ, value, traceback) in kept_errors: - logging.warn('Reraising captured error') - six.reraise(typ, value, traceback) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.error_handling import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py index 8edf131bc24fd003806263570b63ee8514c49896..ded75e975b10c4265370af260bf804687c9caebc 100644 --- a/tensorflow/contrib/tpu/python/tpu/feature_column.py +++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py @@ -1,429 +1,30 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""TPU Feature Column Library.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.feature_column import feature_column as fc -from tensorflow.python.feature_column import feature_column_lib as fc_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import variable_scope -# pylint: disable=protected-access - - -_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope' -_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn, - fc._VocabularyFileCategoricalColumn, - fc._VocabularyListCategoricalColumn, - fc._WeightedCategoricalColumn, - fc_lib.IdentityCategoricalColumn, - fc_lib.VocabularyFileCategoricalColumn, - fc_lib.VocabularyListCategoricalColumn, - fc_lib.WeightedCategoricalColumn) - - -def embedding_column(categorical_column, - dimension, - combiner='mean', - initializer=None): - """TPU embedding_column for `tf.feature_column.embedding_column`. - - Note that the interface for TPU embedding_column is different from the non-TPU - version. The following args available for the non-TPU version are NOT - supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable. - - Args: - categorical_column: A categorical_column returned from - categorical_column_with_identity, weighted_categorical_column, - categorical_column_with_vocabulary_list or - categorical_column_with_vocabulary_file. - dimension: An integer specifying dimension of the embedding, must be > 0. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. For more information, see - `tf.feature_column.embedding_column`. - initializer: A variable initializer function to be used in embedding - variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean `0.0` and standard deviation - `1/sqrt(dimension)`. - - Returns: - A _TPUEmbeddingColumn. - - Raises: - ValueError: if `dimension` not > 0. - ValueError: if `initializer` is specified but not callable. - """ - if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): - raise TypeError( - 'categorical_column for tpu ' - ' embedding_column must be type %s, got %s.' % (' or '.join([ - cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS - ]), type(categorical_column))) - if (dimension is None) or (dimension < 1): - raise ValueError('Invalid dimension {}.'.format(dimension)) - - if (initializer is not None) and (not callable(initializer)): - raise ValueError('initializer must be callable if specified. ' - 'Embedding of column_name: {}'.format( - categorical_column.name)) - if initializer is None: - initializer = init_ops.truncated_normal_initializer( - mean=0.0, stddev=1 / math.sqrt(dimension)) - - embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access - - def _creator(weight_collections, scope): - embedding_column_layer = fc._EmbeddingColumnLayer( - embedding_shape=embedding_shape, - initializer=initializer, - weight_collections=weight_collections, - trainable=True, - name='embedding_column_layer') - return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable - - column = _TPUEmbeddingColumn( - categorical_column=categorical_column, - dimension=dimension, - combiner=combiner, - layer_creator=_creator, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True) - # For Embedding column, the initializer is hidden inside the creator Fn, which - # is not accessiable later. So, we attach it to a speicial field. Also note - # that non-TPU Embedding column and non-TPU shared Embedding column handle the - # initializer differently. See shared_embedding_columns for details. - column._tpu_initializer = initializer - return column - - -def shared_embedding_columns(categorical_columns, - dimension, - combiner='mean', - initializer=None, - shared_embedding_collection_name=None): - """List of dense columns that convert from sparse, categorical input.""" - for categorical_column in categorical_columns: - if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): - raise TypeError( - 'categorical_column for tpu ' - ' shared_embedding_columns must be type %s, got %s.' % (' or '.join([ - cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS - ]), type(categorical_column))) - columns = fc_lib.shared_embedding_columns( - categorical_columns, - dimension, - combiner=combiner, - initializer=initializer, - shared_embedding_collection_name=shared_embedding_collection_name, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True) - - # Use the initializer and shared_embedding_collection_name to create TPU - # version - initializer = columns[0].initializer - shared_embedding_collection_name = columns[0].shared_embedding_collection_name - tpu_columns = [] - - # Create the state (_SharedEmbeddingColumnLayer) here. - for categorical_column in categorical_columns: - column = _TPUSharedEmbeddingColumn( - categorical_column=categorical_column, - dimension=dimension, - combiner=combiner, - initializer=initializer, - shared_embedding_collection_name=shared_embedding_collection_name, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True) - tpu_columns.append(column) - - return tpu_columns - - -class _TPUBaseEmbeddingColumn(object): - """Base class for TPU Embedding Column.""" - - def __init__(self, categorical_column): - self._tpu_categorical_column = categorical_column - - def get_combiner(self): - """Returns the embedding combiner.""" - raise NotImplementedError('not implemented') - - def get_embedding_table_size(self): - """Returns the embedding table size, tuple of vocab size and dimension.""" - raise NotImplementedError('not implemented') - - def get_feature_key_name(self): - """Returns the feature key name in the features dict.""" - raise NotImplementedError('not impl') - - def get_weight_key_name(self): - """Return the key name for weights.""" - raise NotImplementedError('not impl') - - def get_embedding_var_name(self): - """Returns the embedding variable name. - - Feature key name and embedding variable name are usually one-to-one mapping. - But for shared embedding columns, it is many-to-one mapping. - """ - raise NotImplementedError('not impl') - - def get_initializer(self): - """Returns the initializer.""" - raise NotImplementedError('not impl') - - def is_categorical_column_weighted(self): - """Check if the categorical column of the embedding column is weighted.""" - raise NotImplementedError('not impl') - - -class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): - """Core Embedding Column.""" - - def __new__(cls, - categorical_column, - dimension, - combiner='mean', - layer_creator=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable - # are not supported on TPU. They are solely for matching the signature of - # __new__ of parent class fc._EmbeddingColumn. - return fc._EmbeddingColumn.__new__( - cls, - categorical_column, - dimension, - combiner=combiner, - layer_creator=layer_creator, - ckpt_to_load_from=ckpt_to_load_from, - tensor_name_in_ckpt=tensor_name_in_ckpt, - max_norm=max_norm, - trainable=trainable) - - def __init__(self, - categorical_column, - dimension, - combiner='mean', - layer_creator=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - _TPUBaseEmbeddingColumn.__init__(self, categorical_column) - self._key = None - - def get_combiner(self): - return self.combiner - - def get_embedding_table_size(self): - """Returns num_ids and width.""" - return (self.categorical_column._num_buckets, self.dimension) - - def get_feature_key_name(self): - """get_feature_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.categorical_column.name - return self.categorical_column.name - - def get_weight_key_name(self): - """get_weight_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.weight_feature_key - return None - - def get_embedding_var_name(self): - """get_embedding_var_name.""" - return self.categorical_column.name - - def get_initializer(self): - return self._tpu_initializer - - def is_categorical_column_weighted(self): - """Check if the categorical column of the embedding column is weighted.""" - if isinstance( - self.categorical_column, - ( - fc._WeightedCategoricalColumn, # pylint: disable=protected-access - fc_lib.WeightedCategoricalColumn)): - return True - return False - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - if tpu.under_tpu_inference_context(): - def host_computation(): - return fc._EmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - return tpu.outside_compilation(host_computation) - - if _is_running_on_cpu(): - return fc._EmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - - # TPU mode - # Get the embeddings from the LazyBuilder. - tensor = inputs.get(self.get_feature_key_name()) - - # Add to collection for _create_tpu_embedding_variables_and_ops - _record_variable_scope_and_name(self.get_embedding_var_name(), - 'embedding_weights') - - return tensor - - -class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, - fc._SharedEmbeddingColumn): - """Core Shared Embedding Column.""" - - def __new__(cls, - categorical_column, - dimension, - combiner='mean', - initializer=None, - shared_embedding_collection_name=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - return fc._SharedEmbeddingColumn.__new__( - cls, - categorical_column, - dimension, - combiner=combiner, - initializer=initializer, - shared_embedding_collection_name=shared_embedding_collection_name, - ckpt_to_load_from=ckpt_to_load_from, - tensor_name_in_ckpt=tensor_name_in_ckpt, - max_norm=max_norm, - trainable=trainable) - - def __init__(self, - categorical_column, - dimension, - combiner='mean', - initializer=None, - shared_embedding_collection_name=None, - ckpt_to_load_from=None, - tensor_name_in_ckpt=None, - max_norm=None, - trainable=True): - - _TPUBaseEmbeddingColumn.__init__(self, categorical_column) - self._key = None - - def get_combiner(self): - return self.combiner - - def get_embedding_table_size(self): - """Returns num_ids and width.""" - return (self.categorical_column._num_buckets, self.dimension) - - def get_feature_key_name(self): - """get_feature_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.categorical_column.name - return self.categorical_column.name - - def get_weight_key_name(self): - """get_weight_key_name.""" - if self.is_categorical_column_weighted(): - return self.categorical_column.weight_feature_key - return None - - def get_embedding_var_name(self): - """get_embedding_var_name.""" - return self.shared_embedding_collection_name - - def get_initializer(self): - return self.initializer - - def is_categorical_column_weighted(self): - """Check if the categorical column of the embedding column is weighted.""" - if isinstance( - self.categorical_column, - ( - fc._WeightedCategoricalColumn, # pylint: disable=protected-access - fc_lib.WeightedCategoricalColumn)): - return True - return False - - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): - if tpu.under_tpu_inference_context(): - def host_computation(): - return fc._SharedEmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - return tpu.outside_compilation(host_computation) - - if _is_running_on_cpu(): - return fc._SharedEmbeddingColumn._get_dense_tensor( - self, inputs, weight_collections, trainable) - - # TPU mode - # Get the embeddings from the LazyBuilder. - tensor = inputs.get(self.get_feature_key_name()) - - # Add to collection for _create_tpu_embedding_variables_and_ops - _record_variable_scope_and_name( - self.get_embedding_var_name(), - 'embedding_weights', - is_shared_embedding=True) - return tensor - - -def _record_variable_scope_and_name(embedding_var_name, - embedding_var_name_in_fc, - is_shared_embedding=False): - """Add embedding variable name and scope to collection.""" - g = ops.get_default_graph() - collection = g.get_collection_ref(_TPU_FC_TO_SCOPE) - if not collection: - collection.append({}) - - var_def_dict = collection[0] - - captured_scope = None - - if is_shared_embedding and (embedding_var_name in var_def_dict): - if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc: - raise ValueError( - 'For embedding var name {}, the shared embedding name is different, ' - 'got {}; expected {}'.format(embedding_var_name, - embedding_var_name_in_fc, - var_def_dict[embedding_var_name][1])) - else: - # scope contains var_scope_name. - captured_scope = variable_scope.get_variable_scope() - var_def_dict[embedding_var_name] = (captured_scope, - embedding_var_name_in_fc) - - -def _is_running_on_cpu(): - """Returns True if the current context is CPU model.""" - return tpu_function.get_tpu_context().number_of_shards is None +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.feature_column import * +# used by tests +from tensorflow.python.tpu.feature_column import _is_running_on_cpu +from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name +from tensorflow.python.tpu.feature_column import _TPU_FC_TO_SCOPE +from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn +from tensorflow.python.tpu.feature_column import _TPUEmbeddingColumn +from tensorflow.python.tpu.feature_column import _TPUSharedEmbeddingColumn +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py index 24c85156e53a9b770f811c4cf3b903eab6553c76..9a5759221ed9660200cc213df69961db56f8d490 100644 --- a/tensorflow/contrib/tpu/python/tpu/functional.py +++ b/tensorflow/contrib/tpu/python/tpu/functional.py @@ -1,39 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= -"""Functional operations.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import platform - -from tensorflow.contrib.tpu.python.tpu import gen_functional_ops - - -TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access - - -if platform.system() != "Windows": - # pylint: disable=wildcard-import,unused-import,g-import-not-at-top - from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import * - - from tensorflow.contrib.util import loader - from tensorflow.python.platform import resource_loader - # pylint: enable=wildcard-import,unused-import,g-import-not-at-top - - _tpu_partitioned_call_op = loader.load_op_library( - resource_loader.get_path_to_datafile("../ops/_functional_ops.so") - ) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.functional import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 37fe9af8c4b154a2e20a957f6ca5d97df3d413be..6ad4e45e9625f191bb4c01f70b434dc2c4fba638 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -55,8 +55,6 @@ import numpy as np import six from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables from tensorflow.contrib.tpu.python.tpu import tpu @@ -64,6 +62,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops @@ -94,6 +93,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # TODO(b/114775106): temporary shim to optionally initialize the TPU @@ -2172,7 +2172,10 @@ Output shape: %(output_shape)s # pylint: enable=bad-continuation -@experimental +@deprecated( + '2019-02-20', 'Switch to tf.contrib.distribute.TPUStrategy. ' + 'https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy' +) def tpu_model(model, strategy=None): """Copy `model` along with weights to the TPU. diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 5cb2ca6478a1d7589cd2aa2d52c82306b3fd11f4..ed8f9525c9b91208d39805654b01837abdbf3a77 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -1,438 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the 'License'); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Operations for handling session logging and shutdown notifications.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - -import time -from google.protobuf import text_format - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.util import event_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training_util - -_WATCHDOG = None - - -class CoordinatorShutdownException(Exception): - """Raised when the coordinator needs to shutdown.""" - pass - - -def _clone_session(session, graph=None): - return session_lib.Session( - target=session.sess_str, - config=session._config, # pylint: disable=protected-access - graph=graph if graph else session.graph) - - -def _make_heartbeat_op(session, device, request_ph): - """Return a heartbeat op or None if heartbeats are not supported by device.""" - try: - # Test if we can connect in a isolated graph + session - with ops.Graph().as_default(): - with _clone_session(session) as temp_session: - with ops.device(device): - heartbeat_op = tpu_ops.worker_heartbeat('') - options = config_pb2.RunOptions(timeout_in_ms=5000) - temp_session.run(heartbeat_op, options=options) - except errors.InvalidArgumentError as _: - logging.warning('Error running heartbeat on %s', device) - return None - except errors.DeadlineExceededError as _: - logging.warning('Timeout connecting to %s when testing heartbeat', device) - return None - - # If we successfully connected and pinged the worker, go ahead and construct - # the operation. - with ops.device(device): - return tpu_ops.worker_heartbeat(request_ph) - - -class WorkerHeartbeatManager(object): - """Manages the status/heartbeat monitor for a set of workers.""" - - def __init__(self, session, devices, heartbeat_ops, request_placeholder): - """Construct a new WorkerHeartbeatManager. - - (Prefer using `WorkerHeartbeatManager.from_devices` when possible.) - - Args: - session: `tf.Session`, session to use for heartbeat operations. - devices: `list[string]` Set of devices to connect to. - heartbeat_ops: `list[tf.Operation]` Heartbeat operations. - request_placeholder: `tf.Placeholder[String]` Placeholder used to specify - the WorkerHeartbeatRequest protocol buffer. - """ - self._session = session - self._devices = devices - self._ops = heartbeat_ops - self._request_placeholder = request_placeholder - - @staticmethod - def from_devices(session, devices): - """Construct a heartbeat manager for the given devices.""" - if not devices: - logging.error('Trying to create heartbeat manager with no devices?') - - logging.info('Creating heartbeat manager for %s', devices) - request_placeholder = array_ops.placeholder( - name='worker_heartbeat_request', dtype=dtypes.string) - - heartbeat_ops = [] - kept_devices = [] - for device in devices: - heartbeat_op = _make_heartbeat_op(session, device, request_placeholder) - if heartbeat_op is not None: - kept_devices.append(device) - heartbeat_ops.append(heartbeat_op) - else: - logging.warning('Heartbeat support not available for %s', device) - - return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops, - request_placeholder) - - def num_workers(self): - return len(self._devices) - - def configure(self, message): - """Configure heartbeat manager for all devices. - - Args: - message: `event_pb2.WorkerHeartbeatRequest` - Returns: `None` - """ - logging.info('Configuring worker heartbeat: %s', - text_format.MessageToString(message)) - self._session.run(self._ops, - {self._request_placeholder: message.SerializeToString()}) - - def ping(self, request=None, timeout_in_ms=5000): - """Ping all workers, returning the parsed status results.""" - if request is None: - request = event_pb2.WorkerHeartbeatRequest() - - options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms) - results = self._session.run( - self._ops, - feed_dict={self._request_placeholder: request.SerializeToString()}, - options=options) - parsed_results = [ - event_pb2.WorkerHeartbeatResponse.FromString(res_pb) - for res_pb in results - ] - logging.debug('Ping results: %s', parsed_results) - return parsed_results - - def lame_workers(self): - """Ping all workers, returning manager containing lame workers (or None).""" - ping_results = self.ping() - lame_workers = [] - - for ping_response, device, op in zip(ping_results, self._devices, - self._ops): - if ping_response.health_status != event_pb2.OK: - lame_workers.append((device, op)) - - if not lame_workers: - return None - - bad_devices, bad_ops = zip(*lame_workers) - return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, - self._request_placeholder) - - def __repr__(self): - return 'HeartbeatManager(%s)' % ','.join(self._devices) - - def shutdown(self, timeout_ms=10000): - """Shutdown all workers after `shutdown_timeout_secs`.""" - logging.info('Shutting down %s.', self) - req = event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms), - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR) - self.configure(req) - - # Wait for workers to shutdown. This isn't strictly required - # but it avoids triggering multiple checkpoints with the same lame worker. - logging.info('Waiting %dms for worker shutdown.', timeout_ms) - time.sleep(timeout_ms / 1000) - - -def all_worker_devices(session): - """Return a list of devices for each worker in the system.""" - devices = session.list_devices() - return [ - device.name - for device in devices - if ':CPU:' in device.name and 'coordinator' not in device.name - ] - - -class WatchdogManager(threading.Thread): - """Configures worker watchdog timer and handles periodic pings. - - Usage: - # Ping workers every minute, shutting down workers if they haven't received - # a ping after 1 hour. - watchdog_manager = WatchdogManager( - ping_interval=60, shutdown_timeout=3600 - ) - - # Use as a context manager, resetting watchdog on context exit: - with watchdog_manager: - session.run(...) - - # Or setup globally; watchdog will remain active until program exit. - watchdog_manager.configure_and_run() - """ - - def __init__(self, - session, - devices=None, - ping_interval=60, - shutdown_timeout=3600): - """Initialize a watchdog manager. - - Args: - session: Session connected to worker devices. A cloned session and graph - will be created for managing worker pings. - devices: Set of devices to monitor. If none, all workers will be - monitored. - ping_interval: Time, in seconds, between watchdog pings. - shutdown_timeout: Time, in seconds, before watchdog timeout. - """ - threading.Thread.__init__(self) - self.ping_interval = ping_interval - self.shutdown_timeout = shutdown_timeout - self.daemon = True - self._config = session._config # pylint: disable=protected-access - self._target = session.sess_str - self._running = False - self._devices = devices - - self._graph = None - self._session = None - self._worker_manager = None - - def _reset_manager(self): - """Reset the graph, session and worker manager.""" - self._graph = ops.Graph() - self._session = session_lib.Session( - target=self._target, - graph=self._graph, - config=self._config, - ) - - if self._devices is None: - self._devices = all_worker_devices(self._session) - - with self._graph.as_default(): - self._worker_manager = WorkerHeartbeatManager.from_devices( - self._session, self._devices) - - self._worker_manager.configure( - event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig( - timeout_ms=self.shutdown_timeout * 1000,), - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) - - def configure_and_run(self): - logging.info( - 'Enabling watchdog timer with %d second timeout ' - 'and %d second ping interval.', self.shutdown_timeout, - self.ping_interval) - self._reset_manager() - self._running = True - self.start() - - def stop(self): - logging.info('Stopping worker watchdog.') - self._worker_manager.configure( - event_pb2.WorkerHeartbeatRequest( - watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,), - shutdown_mode=event_pb2.NOT_CONFIGURED)) - self._running = False - self.join() - - def __enter__(self): - self.configure_and_run() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop() - - def run(self): - # Don't fetch logs or adjust timing: just ping the watchdog. - # - # If we hit an exception, reset our session as it is likely broken. - while self._running: - try: - self._worker_manager.ping(request=None) - time.sleep(self.ping_interval) - except errors.OpError as e: - # Catch any TF errors that occur so we don't stop sending heartbeats - logging.debug('Caught error while sending heartbeat: %s', e) - self._reset_manager() - - -def start_worker_watchdog(session, - devices=None, - ping_interval=60, - shutdown_timeout=3600): - """Start global worker watchdog to shutdown workers on coordinator exit.""" - global _WATCHDOG - if _WATCHDOG is None: - # Ensure we can send a few pings before we timeout! - ping_interval = min(shutdown_timeout / 10., ping_interval) - _WATCHDOG = WatchdogManager(session, devices, ping_interval, - shutdown_timeout) - _WATCHDOG.configure_and_run() - - -class GracefulShutdownHook(session_run_hook.SessionRunHook): - """Session hook that watches for shutdown events. - - If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a - SystemShutdown exception is raised to terminate the main session. If `saver` - is None the `SAVERS` collection will be read to find a saver. - - `on_shutdown_hooks` is an optional list of functions that should be called - after checkpointing. The function is called with (`run_context`, - `all_workers`, `lame_workers`). - - If `heartbeat_group` is not specified, it will default to all CPU workers - in the system. - """ - - def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None): - self._saver = saver - self._checkpoint_prefix = checkpoint_prefix - self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else [] - - # Worker heartbeats are managed independently of the main training graph. - self._graph = ops.Graph() - self._workers = None - self._session = None - self._heartbeat_supported = False - - def after_create_session(self, training_session, coord): # pylint: disable=unused-argument - # N.B. We have to pull the global step here to avoid it being unavailable - # at checkpoint time; the graph has been frozen at that point. - if training_util.get_global_step() is None and self.saver() is not None: - raise ValueError( - 'Saver defined but no global step. Run `get_or_create_global_step()`' - ' in your model definition to allow checkpointing.') - - with self._graph.as_default(): - logging.info('Installing graceful shutdown hook.') - self._session = _clone_session(training_session, self._graph) - self._workers = WorkerHeartbeatManager.from_devices( - self._session, all_worker_devices(self._session)) - self._heartbeat_supported = self._workers.num_workers() > 0 - if self._heartbeat_supported: - self._workers.configure( - event_pb2.WorkerHeartbeatRequest( - shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) - else: - logging.warn( - 'No workers support hearbeats. Failure handling will be disabled.') - - def saver(self): - if self._saver: - return self._saver - - savers = ops.get_collection(ops.GraphKeys.SAVERS) - if not savers: - return None - - if not isinstance(savers, list): - return savers - - if len(savers) > 1: - logging.error( - 'Multiple savers in the SAVERS collection. On-demand checkpointing ' - 'will be disabled. Pass an explicit `saver` to the constructor to ' - 'override this behavior.') - return None - - return savers[0] - - def after_run(self, run_context, run_values): - del run_values - - if not self._heartbeat_supported: - return - - lame_workers = self._workers.lame_workers() - if lame_workers: - logging.info('ShutdownHook: lame workers found: %s', lame_workers) - - if self.saver(): - logging.info('ShutdownHook: saving checkpoint to %s', - self._checkpoint_prefix) - self.saver().save( - run_context.session, - self._checkpoint_prefix, - global_step=training_util.get_global_step(), - write_state=True, - ) - else: - logging.info('ShutdownHook: no Saver defined.') - - for fn in self._on_shutdown_hooks: - fn(run_context, self._workers, lame_workers) - - -class RestartComputation(object): - """Restart the entire computation. - - This hook shuts down all workers and returns control to the top-level by - throwing a CoordinatorShutdownException. - """ - - def __init__(self, timeout_ms=10000): - self.timeout_ms = timeout_ms - - def __call__(self, run_context, all_workers, lame_workers): - del run_context, lame_workers - all_workers.shutdown(timeout_ms=self.timeout_ms) - - logging.info('Terminating coordinator.') - raise CoordinatorShutdownException() - - -class ShutdownLameWorkers(object): - """Shutdown lamed workers. - - Processing will continue normally (typically by waiting for the down - workers to be restarted). - """ - - def __init__(self, timeout_ms=10000): - self.timeout_in_ms = timeout_ms - - def __call__(self, run_context, all_workers, lame_workers): - lame_workers.shutdown(timeout_ms=self.timeout_in_ms) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.session_support import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py index 2c5ea65182e404ec44b24bcd7d0f412c04f1beb1..73db253fd790f26679fb05bd6e7a5da6a99da1a7 100644 --- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -1,1481 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ======================================================================== -"""A utility to trace tensor values on TPU.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import os.path -import re -import sys - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import control_flow_util -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging - -_TRACER_LOG_PREFIX = ' [>>>TT>>>]' -_DEVICE_TYPE_TPU = 'tpu' -_DEVICE_TYPE_CPU = 'cpu' -_TRACE_MODE_NAN_INF = 'nan-inf' -_TRACE_MODE_PART_TENSOR = 'part-tensor' -_TRACE_MODE_PART_TENSOR_SIZE = 3 -_TRACE_MODE_FULL_TENSOR = 'full-tensor' -_TRACE_MODE_NORM = 'norm' -_TRACE_MODE_MAX_ABS = 'max-abs' -_SUBMODE_BRIEF = 'brief' -_SUBMODE_DETAILED = 'detailed' -_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' -_REASON_UNSAFE_OP = 'not-traced-unsafe-op' -_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' -_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' -_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' -_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' -_REASON_SCALAR_GET_TRACED = 'traced-scalar' -_REASON_TENSOR_GET_TRACED = 'traced-tensor' -_REASON_USER_INCLUDED = 'traced-user-included' -_REASON_USER_EXCLUDED = 'not-traced-user-excluded' -_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' -_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' -_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' -_MARKER_SECTION_END = '!!!!!!! section-end:' -_SECTION_NAME_CONFIG = 'configuration' -_SECTION_NAME_REASON = 'reason' -_SECTION_NAME_OP_LIST = 'op-list' -_SECTION_NAME_TENSOR_LIST = 'tensor-list' -_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map' -_SECTION_NAME_GRAPH = 'graph' -_FIELD_NAME_VERSION = 'version:' -_FIELD_NAME_DEVICE = 'device:' -_FIELD_NAME_TRACE_MODE = 'trace-mode:' -_FIELD_NAME_SUBMODE = 'submode:' -_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' -_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:' -_FIELD_NAME_NUM_HOSTS = 'num-hosts:' -_FIELD_NAME_NUM_OPS = 'number-of-ops:' -_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' -_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:' -_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' -_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' -_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") -_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') -_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') -_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') -_FLAG_NAME_ENABLE = 'enable' -_FLAG_NAME_TRACE_MODE = 'trace_mode' -_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace' -_FLAG_NAME_SUBMODE = 'submode' -_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' -_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' -_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' -_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' -_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' -_FLAG_NAME_TRACE_DIR = 'trace_dir' -_FLAG_NAME_REPORT_FILE = 'report_file' -_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' -_FLAG_NAME_OP_RANGE = 'op_range' -_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') -_OUTPUT_STREAM_ESCAPE = 'file://' -_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' -_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' -_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' -_TRACE_FILE_NAME = 'trace.all' -_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' -_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0 -_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage' -_TENSOR_VALUES_CACHE = 'tensor_values_cache' -_REPLICA_ID_TAG = '#replica-id: ' - -def tensor_tracepoint(tensor, checkpoint_name): - """Adds a checkpoint with the given checkpoint name for the given tensor. - - The tensor will be added to the list of tensors that will be traced by the - tensor tracer. - - Args: - tensor: the tensor object for which the tracing is requested. - checkpoint_name: a string name for the checkpoint. This name has to be a - unique name if used within model comparison. The tensors that have the same - checkpoint identifier is compared in model comparison. - Returns: - The provided tensor. - """ - - tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) - tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, - (tensor, checkpoint_name)) - return tensor - - -def keras_layer_tracepoint(layer, checkpoint_name): - """An interface for adding the tensor outputs of a keras layer. - - Encapsulates tensor_tracepoint. - - Args: - layer: A keras layer. - checkpoint_name: a string name for the checkpoint. This name has to be a - unique name if used within model comparison. The tensors that have the same - checkpoint identifier is compared in model comparison. - - Returns: - The provided layer. - """ - try: - outputs = layer.output - if tensor_util.is_tensor(outputs): - tensor_tracepoint(outputs, '%s' % (checkpoint_name)) - else: - idx = 0 - for output_tensor in outputs: - if tensor_util.is_tensor(outputs): - tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) - idx += 1 - except AttributeError: - pass - except RuntimeError: - pass - return layer - - -def _trace_files_need_precreated(output_dir): - """Return True if trace files must be pre-created by users.""" - - if not output_dir.startswith('/'): - return False - if len(output_dir) < 5: - return False - if output_dir[2] != 'n': - return False - if output_dir[3] != 's': - return False - if output_dir[1] != 'c': - return False - if output_dir[4] != '/': - return False - return True - - -def _get_tensor_values_cache(graph=None): - """Returns the variable that implements tensor-value caching.""" - - graph = graph or ops.get_default_graph() - collection = graph.get_collection(_TENSOR_TRACER_STORAGE) - if len(collection) == 1: - return collection[0] - elif not collection: - raise RuntimeError('%s has not been created'%_TENSOR_VALUES_CACHE) - else: - raise RuntimeError('Multiple %s created'%_TENSOR_VALUES_CACHE) - return None - - -def _create_tensor_values_cache(graph, num_tensors): - """Creates a variable as the cache to store intermediate tensor values.""" - - graph = graph or ops.get_default_graph() - # Create in proper graph and base name_scope. - with graph.as_default() as g, g.name_scope(None): - return variable_scope.get_variable( - _TENSOR_VALUES_CACHE, - shape=[num_tensors], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer( - _COMPACT_TRACE_ENTRY_INIT_VALUE), - trainable=False, - use_resource=True, - collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.GLOBAL_VARIABLES]) - - -def _set_fetches(result_tensor, train_op): - """Sets the fetches from the result tensor and training op.""" - - fetches = [] - if result_tensor is not None: - fetches.append(result_tensor) - if train_op is not None: - fetches.append(train_op) - if not fetches: - return None - return fetches - - -class TensorTracer(object): - """A software construct for tracing tensor values in a TF graph on TPU. - - This utility is disabled by default. It can be enabled by setting - the TENSOR_TRACER_FLAGS env variable as: - export TENSOR_TRACER_FLAGS="--enable=1" - If it is enabled, it will trace the output tensor values of - selected Ops in the graph. It has two outputs: (1) the traces and (2) - a report. The traces are dumped to a specified local file on the TPU - host. The report is printed to the log.info of the TPU job. - By passing options via the env variable, users can change: - (1) the trace mode (e.g., detecting NaN/Inf, printing partial or - full tensor values) - (2) which Ops to be traced (via op.name or op.type) - (3) output trace file path. - """ - - @staticmethod - def _match_next_flag(flags, pos): - """Returns the match for the next TensorTracer flag. - - Args: - flags: a string that contains the flags. - pos: where in flags to start the search. - - Returns: - A pair where the first element is the regular-expression - match found and the second element indicates if the match - has a value. - """ - - match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) - if match: - return match, True - match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) - if match: - return match, True - match = _FLAG_NO_QUOTE_PAT.match(flags, pos) - if match: - return match, True - match = _FLAG_NO_EQUAL_PAT.match(flags, pos) - if match: - # The flag is found but is not given a value. - return match, False - # The flag is not found. - return None, False - - @staticmethod - def validate_flag_names(): - """Validates if the TensorTrace flags passed are valid.""" - valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, - _FLAG_NAME_USE_COMPACT_TRACE, - _FLAG_NAME_SUBMODE, - _FLAG_NAME_EXCLUDED_OPNAMES, - _FLAG_NAME_EXCLUDED_OPTYPES, - _FLAG_NAME_INCLUDED_OPNAMES, - _FLAG_NAME_INCLUDED_OPTYPES, - _FLAG_NAME_TRACE_DIR, - _FLAG_NAME_REPORT_FILE, - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, - _FLAG_NAME_OP_RANGE] - tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return - pos = 0 - while True: - match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos) - if not match: - break - flag_name = match.group(1) - if flag_name not in valid_flag_names: - raise ValueError( - 'The flag name "%s" passed via the environment variable "%s" ' - 'is invalid. Valid flag names are:' - '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) - pos = match.end() - - @staticmethod - def print_flag_values(): - """Prints all TensorTracer flags passed via environment variables.""" - - tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR - result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, - tensor_tracer_flags) - result += 'Individual flag value:\n' - pos = 0 - while True: - match, has_value = TensorTracer._match_next_flag( - tensor_tracer_flags, pos) - if not match: - break - flag_name = match.group(1) - if has_value: - flag_value = match.group(2) - else: - flag_value = None - result += ' %s: %s\n'%(flag_name, flag_value) - pos = match.end() - result += '\n' - return result - - @staticmethod - def get_flag_value(wanted_flag_name): - """Returns the value of a TensorTracer flags. - - Args: - wanted_flag_name: the name the the flag we are looking for. - - Returns: - A pair where the first element indicates if the flag is - found and the second element is the value of the flag. - - Raises: - RuntimeError: If supposedly deadcode is reached. - """ - - tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) - if not tensor_tracer_flags: - return False, None - pos = 0 - while True: - match, has_value = TensorTracer._match_next_flag( - tensor_tracer_flags, pos) - if not match: - return False, None - flag_name = match.group(1) - if has_value: - flag_value = match.group(2) - else: - flag_value = None - if flag_name == wanted_flag_name: - return True, flag_value - pos = match.end() - raise RuntimeError('Should not reach here.') - - @staticmethod - def flag_value_to_re_list(flag_name): - """Converts list of strings to compiled RE.""" - - re_list = [] - found, flag_value = TensorTracer.get_flag_value(flag_name) - if not found or not flag_value: - return re_list - list_of_values = flag_value.split() - for v in list_of_values: - r = re.compile(v) - re_list.append(r) - return re_list - - @staticmethod - def _is_flag_on(flag_name): - """Returns True if the given flag is on.""" - - found, flag_value = TensorTracer.get_flag_value(flag_name) - if not found: - return False - if flag_value is None: - return True - # Depends on the flag value. - flag_value = flag_value.lower() - enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] - return enabled - - @staticmethod - def is_enabled(): - """Returns True if TensorTracer is enabled.""" - - return TensorTracer._is_flag_on(_FLAG_NAME_ENABLE) - - @staticmethod - def use_test_undeclared_outputs_dir(): - """Decides the output directory of the report and trace files. - - Args: - None. - - Returns: - True if the output files should be written to the - test-undeclared-outputs-directory defined via an - env variable. - """ - - return TensorTracer._is_flag_on( - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) - - @staticmethod - def use_compact_trace(): - return TensorTracer._is_flag_on( - _FLAG_NAME_USE_COMPACT_TRACE) - - @staticmethod - def check_device_type(device_type): - """Checks if the given device type is valid.""" - - if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: - raise ValueError('Invalid device_type "%s"'%device_type) - - @staticmethod - def check_trace_mode(trace_mode): - """Checks if the given trace mode is valid.""" - - valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, - _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, - _TRACE_MODE_MAX_ABS] - if trace_mode not in valid_trace_modes: - raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' - 'Valid trace modes are: %s'%(trace_mode, - valid_trace_modes)) - - @staticmethod - def check_submode(submode): - """Checks if the given submode is valid.""" - - if not submode: - return - valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF] - if submode not in valid_submodes: - raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.' - 'Valid submodes are: %s'%(submode, - valid_submodes)) - - @staticmethod - def unsafe_op(op): - """Returns True if this op is not safe to be traced.""" - - if control_flow_util.IsInCond(op): - return True - # Reasons for not including following op types: - # Assign: cause incorrect result with CPU tracing. - if op.type in ['Assign']: - return True - return False - - @staticmethod - def device_mismatch(device_type, op): - if device_type == _DEVICE_TYPE_TPU: - # pylint: disable=protected-access - return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr - # pylint: enable=protected-access - return False - - @staticmethod - def unsafe_scalar_trace(op): - """Return true if scalar output tensor from Op is not safe to be traced.""" - - # Tracing the following causes cycle in the graph on TPU. - if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', - 'Switch', 'Less', 'ReadVariableOp']: - return True - # Tracing the following will cause casting-issue - # with the norm tracing mode or other compilation issues on CPU. - if op.type in ['VarHandleOp', 'IteratorToStringHandle', - 'IteratorGetNext', 'OneShotIterator', - 'IteratorV2', 'MakeIterator', - 'BatchDatasetV2', 'MapDataset', - 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', - 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: - return True - return False - - @staticmethod - def less_interesting_op(op): - """Returns True if the given Op is not an interesting one to be traced.""" - - found, _ = TensorTracer.get_flag_value( - _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) - if found: - # users force to include all ops. - return False - # Following ops are highly unlikey to cause bugs. - return op.type in ['Const', 'Identity', 'Cast', 'Shape'] - - @staticmethod - def reason(op_idx, details): - """Returns reason why the Op at op_idx is traced or not.""" - - return '%d %s'%(op_idx, details) - - @staticmethod - def topological_sort(g): - """Performs topological sort on the given graph. - - Args: - g: the graph. - - Returns: - A pair where the first element indicates if the topological - sort succeeded (True if there is no cycle found; False if a - cycle is found) and the second element is either the sorted - list of nodes or the cycle of nodes found. - """ - - def visit(op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops): - """Recursively visits all Ops in a graph. - - Args: - op: the current Op being visited. - cycle: a cycle of Ops found. - permanently_marked_ops: the set of Ops that were already visited. - temporarily_marked_ops: the set of Ops that we have visited during - the current descent. - sorted_ops: the list of Ops sorted in topological order. - """ - - if cycle: - return - if op in permanently_marked_ops: - return - if op in temporarily_marked_ops: - cycle = temporarily_marked_ops - return - temporarily_marked_ops.add(op) - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - for consumer_op in out_tensor.consumers(): - visit(consumer_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - # pylint: disable=protected-access - for ctrl_output_op in op._control_outputs: - # pylint: enable=protected-access - visit(ctrl_output_op, cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - temporarily_marked_ops.remove(op) - permanently_marked_ops.add(op) - sorted_ops.insert(0, op) - - graph_cycle = set([]) - sorted_ops = [] - permanently_marked_ops = set([]) - temporarily_marked_ops = set([]) - unsorted_ops = g.get_operations() - for op in unsorted_ops: - visit(op, graph_cycle, permanently_marked_ops, - temporarily_marked_ops, sorted_ops) - if graph_cycle: - return (False, graph_cycle) - else: - assert len(unsorted_ops) == len(sorted_ops) - return (True, sorted_ops) - - @staticmethod - def _make_op_and_tensor_maps(op_list): - """Creates various maps and lists from op_list. - - Args: - op_list: a list of Ops - - Returns: - opname_idx_map: a map from Op's name to its index in op_list. - tensor_list: a list of output tensors of the Ops in op_list. - tensorname_idx_map: a map from output tensor name to its index - in tensor_list. - """ - - opname_idx_map = {} - tensor_list = [] - tensorname_idx_map = {} - for op_id, op in enumerate(op_list): - if op.name in opname_idx_map: - raise ValueError('Duplicated Op name: %s'%op.name) - opname_idx_map[op.name] = op_id - for output_tensor in op.outputs: - if output_tensor.name not in tensorname_idx_map: - tensor_list.append(output_tensor) - tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 - return (opname_idx_map, tensor_list, tensorname_idx_map) - - def __init__(self): - """Initializes a TensorTracer. - - Sets the various member fields from the flags (if given) or the defaults. - """ - self._version = 'use-outside-compilation' - self._device_type = None - TensorTracer.validate_flag_names() - found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) - if not found or not self._trace_mode: - self._trace_mode = _TRACE_MODE_NAN_INF - TensorTracer.check_trace_mode(self._trace_mode) - found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE) - if not found or not self._submode: - self._submode = _SUBMODE_DETAILED - TensorTracer.check_submode(self._submode) - self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE - self._instrument_records = {} - self._set_trace_dir() - self._set_report_file() - self._set_op_range() - self._set_excluded_opnames() - self._set_excluded_optypes() - self._set_included_opnames() - self._set_included_optypes() - self._num_replicas = None - self._num_replicas_per_host = None - self._num_hosts = None - self._replica_id = None - - def _add_replica_id_to_graph(self, result_tensor): - """Adds nodes for computing the replica ID to the graph.""" - - if not self._num_replicas: - self._replica_id = 'unknown' - return result_tensor - - with ops.control_dependencies(None): - # Uses None as dependency to run outside of TPU graph rewrites. - self._replica_id = tpu_ops.tpu_replicated_input( - list(range(self._num_replicas)), - name='tt_replica_id') - use_replica_id = array_ops.identity(self._replica_id).op - with ops.control_dependencies([use_replica_id]): - # Adds a control dependency from the result_tensor to - # the replica_id to ensure that replica_id will be added to the graph. - return array_ops.identity(result_tensor) - - def _set_trace_dir(self): - found, self._trace_dir = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_DIR) - if found and self._trace_dir \ - and TensorTracer.use_test_undeclared_outputs_dir(): - raise ValueError('Cannot not use --%s and --%s at the same time' - %(_FLAG_NAME_TRACE_DIR, - _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)) - if TensorTracer.use_test_undeclared_outputs_dir(): - self._trace_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) - - def _set_report_file(self): - """Sets the path of the output report file.""" - - found, self._report_file_path = TensorTracer.get_flag_value( - _FLAG_NAME_REPORT_FILE) - if found and self._report_file_path \ - and TensorTracer.use_test_undeclared_outputs_dir(): - if os.path.isabs(self._report_file_path): - raise ValueError('If use_test_undeclared_outputs_dir is set,' - 'report_file_path cannot be an absolute path (%s)' - %self._report_file_path) - outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) - self._report_file_path = os.path.join(outputs_dir, - self._report_file_path) - if not self._report_file_path: - self._report_file = None - return - try: - self._report_file = gfile.Open(self._report_file_path, 'w') - except IOError as e: - raise e - - def _close_report_file(self): - if self._report_file: - self._report_file.close() - - def _set_op_range(self): - """Sets the index range of the Ops that we will consider tracing.""" - - found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) - if not found or not op_range: - self._op_range = (-1, -1) # this means including all ops. - return - match = _OP_RANGE_PAT.match(op_range) - if not match: - self._op_range = (-1, -1) # this means including all ops. - return - self._op_range = (int(match.group(1)), int(match.group(2))) - - def _inside_op_range(self, idx): - """Return True if the given index is inside the selected range.""" - - if idx < self._op_range[0]: - return False - return self._op_range[1] < 0 or idx <= self._op_range[1] - - def _set_excluded_opnames(self): - self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_EXCLUDED_OPNAMES) - - def _set_excluded_optypes(self): - self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_EXCLUDED_OPTYPES) - - def _set_included_opnames(self): - self._included_opname_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_INCLUDED_OPNAMES) - - def _set_included_optypes(self): - self._included_optype_re_list = TensorTracer.flag_value_to_re_list( - _FLAG_NAME_INCLUDED_OPTYPES) - - def _is_user_included_op(self, op): - for opname_re in self._included_opname_re_list: - if opname_re.match(op.name): - return True - for optype_re in self._included_optype_re_list: - if optype_re.match(op.type): - return True - return False - - def _is_user_excluded_op(self, op): - for opname_re in self._excluded_opname_re_list: - if opname_re.match(op.name): - return True - for optype_re in self._excluded_optype_re_list: - if optype_re.match(op.type): - return True - return False - - def _use_tensor_values_cache(self): - """Returns True if immediate tensors should be first saved to a cache.""" - - if self._trace_mode not in set([_TRACE_MODE_NAN_INF, - _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS]): - return False - if self._trace_dir and _trace_files_need_precreated(self._trace_dir): - return True - if TensorTracer.use_compact_trace(): - return True - return False - - def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates): - """Returns an Op that will save the given updates to an entry in the cache.""" - - cache = _get_tensor_values_cache(graph) - indices = constant_op.constant([cache_idx]) - return state_ops.scatter_update(cache, indices, updates).op - - def _write_report(self, content): - """Writes the given content to the report.""" - - line = '%s %s'%(_TRACER_LOG_PREFIX, content) - if self._report_file: - self._report_file.write(line) - else: - logging.info(line) - - def _write_config_section(self): - """Writes the config section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) - self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) - self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) - self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) - self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode)) - self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) - self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST, - self._num_replicas_per_host)) - self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) - - def _write_reason_section(self): - """Writes the reason section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) - for key in sorted(self._instrument_records): - self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) - - def _write_op_list_section(self, op_list): - """Writes the Op-list section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) - self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) - for i in range(0, len(op_list)): - op = op_list[i] - line = '%d "%s" %s'%(i, op.name, op.type) - for out_tensor in op.outputs: - if out_tensor.name not in self._tensorname_idx_map: - raise ValueError( - 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) - line += ' %d'%self._tensorname_idx_map[out_tensor.name] - line += '\n' - self._write_report(line) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) - - def _write_tensor_list_section(self, tensor_list, opname_idx_map): - """Writes the tensor-list section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, - _SECTION_NAME_TENSOR_LIST)) - self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) - for i in range(0, len(tensor_list)): - tensor = tensor_list[i] - line = '%d "%s"'%(i, tensor.name) - for consumer_op in tensor.consumers(): - if consumer_op.name not in opname_idx_map: - raise ValueError( - 'consumer_op %s is not in opname_idx_map'%consumer_op.name) - line += ' %d'%opname_idx_map[consumer_op.name] - line += '\n' - self._write_report(line) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, - _SECTION_NAME_TENSOR_LIST)) - - def _write_cache_index_map_section(self): - """Writes the mapping from cache index to tensor index to the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, - _SECTION_NAME_CACHE_INDEX_MAP)) - self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES, - len(self._cache_idx_to_tensor_idx))) - for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)): - tensor_idx = self._cache_idx_to_tensor_idx[cache_idx] - line = '%d %d\n'%(cache_idx, tensor_idx) - self._write_report(line) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, - _SECTION_NAME_CACHE_INDEX_MAP)) - - def _write_graph_section(self, succeed, sorted_or_cycle): - """Writes the graph section of the report.""" - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) - self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, - succeed)) - l = list(sorted_or_cycle) - for i in range(0, len(l)): - self._write_report('%d "%s"\n'%(i, l[i].name)) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) - - def _preprocess_traced_tensor(self, tensor): - """Computes NAN/Norm/Max on TPUs before sending to CPU. - - Args: - tensor: The tensor to be traced. - Returns: - A tensor that should be input to the trace_function. - Raises: - RuntimeError: If the trace mode is invalid. - """ - - def _detect_nan_inf(tensor): - """Trace function for detecting any NaN/Inf in the tensor.""" - - if tensor.dtype.is_floating: - mask = math_ops.reduce_any( - gen_math_ops.logical_or( - gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) - output_tensor = control_flow_ops.cond(mask, - lambda: constant_op.constant(1.0), - lambda: constant_op.constant(0.0)) - else: - output_tensor = constant_op.constant(0.0) - # The shape has to be 1. Set it if it does not have the information. - output_tensor = array_ops.reshape(output_tensor, [1]) - return output_tensor - - def _show_norm(tensor): - tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = linalg_ops.norm(tensor) - # The shape has to be 1. Set it if it does not have the information. - output_tensor = array_ops.reshape(output_tensor, [1]) - return output_tensor - - def _show_max_abs(tensor): - tensor = math_ops.cast(tensor, dtypes.float32) - output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) - zero = constant_op.constant(0, dtypes.float32) - output_tensor = gen_math_ops.maximum(zero, output_tensor) - # The shape has to be 1. Set it if it does not have the information. - output_tensor = array_ops.reshape(output_tensor, [1]) - return output_tensor - - if self._trace_mode == _TRACE_MODE_NAN_INF: - return _detect_nan_inf(tensor) - if self._trace_mode == _TRACE_MODE_PART_TENSOR: - return tensor - if self._trace_mode == _TRACE_MODE_FULL_TENSOR: - return tensor - if self._trace_mode == _TRACE_MODE_NORM: - return _show_norm(tensor) - if self._trace_mode == _TRACE_MODE_MAX_ABS: - return _show_max_abs(tensor) - raise RuntimeError( - 'Tensor trace fun for %s is not yet implemented' % self._trace_mode) - - def _make_tensor_trace_fun(self, tensor_name): - """Makes the tensor tracing function called by outside compilation. - - Args: - tensor_name: name of the tensor being traced. - - Returns: - A function to be passed as the first argument to outside compilation. - - Raises: - RuntimeError: If the trace mode is invalid. - """ - - def _print_tensor(tensor_name, num_elements, tensor, output_tensor): - """Prints a tensor value to a file. - - Args: - tensor_name: name of the tensor being traced. - num_elements: number of elements to print (-1 means print all). - tensor: the tensor needs to be returned. - output_tensor: the tensor needs to be printed. - - Returns: - The same tensor passed via the "tensor" argument. - - Raises: - ValueError: If tensor_name is not already in - self._tensorname_idx_map. - """ - - if self._submode == _SUBMODE_BRIEF: - if tensor_name not in self._tensorname_idx_map: - raise ValueError( - 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) - msg = '%d'%self._tensorname_idx_map[tensor_name] - else: - msg = '"%s"'%tensor_name - - if self._trace_dir: - output_path = os.path.join(self._trace_dir, _TRACE_FILE_NAME) - output_stream = _OUTPUT_STREAM_ESCAPE + output_path - else: - output_stream = sys.stderr - print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), - '@', self._replica_id, - '\n', output_tensor, '\n', - summarize=num_elements, - output_stream=output_stream) - with ops.control_dependencies([print_op]): - return array_ops.identity(tensor).op - - - def _show_part_tensor(tensor): - """Trace function for printing part of the tensor.""" - - return _print_tensor(tensor_name, self._part_tensor_size, - tensor, tensor) - - def _show_full_tensor(tensor): - """Trace function for printing the entire tensor.""" - - return _print_tensor(tensor_name, -1, tensor, tensor) - - if self._trace_mode == _TRACE_MODE_PART_TENSOR: - return _show_part_tensor - # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, - # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are - # performed within TPUs and only their results are transferred to CPU. - # Simply, print the full tensor for these trace modes. - if self._trace_mode in [ - _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR, - _TRACE_MODE_MAX_ABS - ]: - return _show_full_tensor - - raise RuntimeError('Tensor trace fun for %s is not yet implemented' - %self._trace_mode) - - def _skip_op(self, op_id, op, user_included, user_excluded, - in_exec_path=True): - """Returns True if we should not trace Op.""" - - if user_included: - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_USER_INCLUDED) - return False - if user_excluded: - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_USER_EXCLUDED) - return True - if not in_exec_path: - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_NOT_EXECUTED) - return True - if not self._inside_op_range(op_id): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_OUTSIDE_OP_RANGE) - return True - if TensorTracer.unsafe_op(op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_UNSAFE_OP) - return True - if TensorTracer.device_mismatch(self._device_type, op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_DEVICE_MISMATCH) - return True - if TensorTracer.less_interesting_op(op): - self._instrument_records[op.name] = TensorTracer.reason( - op_id, _REASON_LESS_INTERESTING_OP) - return True - return False - - def _skip_tensor(self, op_id, out_tensor, user_included, - user_excluded): - """Returns True if we should not trace out_tensor.""" - - # Skips a tensor if the tensor has a non-numeric type. - # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) - # because it also excludes tensors with dtypes, bool, and - # float32_ref, which we actually want to trace. - non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, - dtypes.string]) - if out_tensor.dtype in non_numeric_tensor_types: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_NON_NUMERIC_TENSOR) - return True - - if user_included: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_USER_INCLUDED) - return False - if user_excluded: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_USER_EXCLUDED) - return True - if not out_tensor.get_shape().is_fully_defined(): - # If trace mode is nan-inf, norm or max, then the tensor will be reduced - # to a scalar before the outside compilation call. - if self._trace_mode in [ - _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS - ]: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_TENSOR_GET_TRACED) - return False - else: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_DYNAMIC_SHAPE) - return True - rank = len(out_tensor.shape) - if rank < 1: - # scalar - if TensorTracer.unsafe_scalar_trace(out_tensor.op): - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_UNSAFE_SCALAR) - return True - else: - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_SCALAR_GET_TRACED) - return False - else: - # tensor - self._instrument_records[out_tensor.name] = TensorTracer.reason( - op_id, _REASON_TENSOR_GET_TRACED) - return False - - def _filter_execution_path_operations(self, operations, fetches): - """Returns the set of ops in the execution path to compute given fetches.""" - - # If no fetch provided, then return all operations. - if fetches is None: - return set(operations) - # Convert to list, if a single element is provided. - if not isinstance(fetches, (list, tuple)): - fetches = [fetches] - # If a tensor is given as fetch, convert it to op. - op_fetches = [] - for fetch in fetches: - if isinstance(fetch, ops.Operation): - op_fetches.append(fetch) - elif isinstance(fetch, ops.Tensor): - op_fetches.append(fetch.op) - else: - raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' - %fetch) - - execution_path_operations = set(op_fetches) - traverse_stack = list(op_fetches) - while True: - if not traverse_stack: - break - head_op = traverse_stack.pop() - input_ops = [tensor_input.op for tensor_input in head_op.inputs] - input_ops.extend(head_op.control_inputs) - - for input_op in input_ops: - if input_op not in execution_path_operations: - execution_path_operations.add(input_op) - traverse_stack.append(input_op) - return execution_path_operations - - def _determine_traced_tensors(self, graph, fetches): - """Determines the tensors that will be traced.""" - - self._traced_tensorname_to_cache_idx_map = {} - self._cache_idx_to_tensor_idx = [] - operations = graph.get_operations() - # Filter out the operations that won't be executed. - # if fetches=None, then ops_in_exec_path = set(operations) - ops_in_exec_path = self._filter_execution_path_operations(operations, - fetches) - checkpoint_operations = self._get_checkpoints(graph) - for op_id, op in enumerate(operations): - if checkpoint_operations and op.name not in checkpoint_operations: - continue - user_included = self._is_user_included_op(op) - user_excluded = self._is_user_excluded_op(op) - in_exec_path = op in ops_in_exec_path - if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path): - continue - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - if self._skip_tensor(op_id, out_tensor, user_included, - user_excluded): - continue - tensor_name = out_tensor.name - if tensor_name in self._traced_tensorname_to_cache_idx_map: - raise ValueError( - 'Tensor name %s should not be already in ' - 'traced_tensorname_to_cache_idx_map'%tensor_name) - if tensor_name not in self._tensorname_idx_map: - raise ValueError( - 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) - tensor_idx = self._tensorname_idx_map[tensor_name] - cache_idx = len(self._traced_tensorname_to_cache_idx_map) - self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx - self._cache_idx_to_tensor_idx.append(tensor_idx) - if len(self._traced_tensorname_to_cache_idx_map) != len( - self._cache_idx_to_tensor_idx): - raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != ' - 'len(self._cache_idx_to_tensor_idx') - - def _check_trace_files(self): - """Checks if any requirements for trace files are satisfied.""" - - if not self._trace_dir: - # traces will be written to stderr. No need to check trace files. - return - if _trace_files_need_precreated(self._trace_dir): - for replica_id in range(0, self._num_replicas): - trace_file_path = os.path.join( - self._trace_dir, - _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id - if not gfile.Exists(trace_file_path): - raise RuntimeError( - '%s must be pre-created with the ' - 'appropriate properties.'%trace_file_path) - else: - if not gfile.Exists(self._trace_dir): - gfile.MkDir(self._trace_dir) - if not gfile.Exists(self._trace_dir): - raise RuntimeError('Failed to create %s'%self._trace_dir) - - def _pre_tracing(self, graph, fetches): - """Work needs to be done prior to TPU or CPU tracing.""" - - self._check_trace_files() - operations = graph.get_operations() - (opname_idx_map, tensor_list, self._tensorname_idx_map) = ( - TensorTracer._make_op_and_tensor_maps(operations)) - self._write_config_section() - self._write_op_list_section(operations) - self._write_tensor_list_section(tensor_list, opname_idx_map) - self._determine_traced_tensors(graph, fetches) - self._write_cache_index_map_section() - # Does the topological sort before adding any nodes to the graph. - (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) - if self._use_tensor_values_cache(): - _create_tensor_values_cache(graph, - len(self._cache_idx_to_tensor_idx)) - return (operations, succeed, sorted_or_cycle) - - def _post_tracing(self, succeed, sorted_or_cycle): - """Work needs to be done after TPU or CPU tracing.""" - - self._write_reason_section() - self._write_graph_section(succeed, sorted_or_cycle) - self._close_report_file() - - def _get_checkpoints(self, graph): - """Returns the list of Ops that produce the tensors traced with API. - - Args: - graph: the graph of Ops. - - Returns: - A set of operation names which should be traced. - """ - - self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, - _TENSOR_TRACER_CHECKPOINT)) - checkpoint_operations = set() - tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) - for (tensor, checkpoint_name) in tensor_tracer_variables: - self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) - checkpoint_operations.add(tensor.op.name) - self._write_report('%s %s\n'%(_MARKER_SECTION_END, - _TENSOR_TRACER_CHECKPOINT)) - return checkpoint_operations - - def _generate_flush_cache_op(self, graph, start_replica, on_tpu): - """Generates an Op that will flush the cache to file. - - Args: - graph: the graph of Ops - start_replica: the ID of the first replica being flushed by this Op. - on_tpu: if the graph is executed on TPU. - - Returns: - The Op to flush the cache to file. - """ - def _make_flush_fun(replica_id): - """Makes a function for flushing the cache for the given replica.""" - - def _fun(): - """A function that flushes the cache to a file.""" - - def _flush_fun(cache): - """Flushes the cache to a file.""" - - if isinstance(replica_id, str): - replica_id_str = replica_id - else: - replica_id_str = '%d'%replica_id - output_path = os.path.join(self._trace_dir, - _COMPACT_TRACE_FILE_PREFIX) \ - + replica_id_str - output_stream = _OUTPUT_STREAM_ESCAPE + output_path - new_step_line = _REPLICA_ID_TAG + replica_id_str - print_op = logging_ops.print_v2( - new_step_line, '\n', - cache, '\n', - summarize=-1, - output_stream=output_stream) - with ops.control_dependencies([print_op]): - return constant_op.constant(0).op - - cache = _get_tensor_values_cache(graph) - if on_tpu: - flush_op = tpu.outside_compilation(_flush_fun, cache.value()) - else: - flush_op = _flush_fun(cache.value()) - with ops.control_dependencies([flush_op]): - reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, - dtype=cache.dtype, - shape=cache.shape) - assign_op = state_ops.assign(cache, reset_value).op - with ops.control_dependencies([assign_op]): - return flush_op.outputs[0] - - return _fun - - def _f(replica_id): - return _make_flush_fun(replica_id) - def _eq(x): - return math_ops.equal(x, self._replica_id) - def _do_nothing(): - return constant_op.constant(0) - - return control_flow_ops.case({\ - _eq(start_replica): _f(start_replica), \ - _eq(start_replica+1): _f(start_replica+1), \ - _eq(start_replica+2): _f(start_replica+2), \ - _eq(start_replica+3): _f(start_replica+3), \ - _eq(start_replica+4): _f(start_replica+4), \ - _eq(start_replica+5): _f(start_replica+5), \ - _eq(start_replica+6): _f(start_replica+6), \ - _eq(start_replica+7): _f(start_replica+7), \ - }, - default=_do_nothing, - exclusive=True).op - - def _flush_tensor_values_cache(self, graph, result_tensor, train_op, on_tpu): - """Flushes the intermediate tensor values in the graph to the cache. - - Args: - graph: the graph of Ops - result_tensor: a result tensor of evaluating the graph. - train_op: the training op. - on_tpu: if the graph is executed on TPU. - - Returns: - An identical copy of result tensor. - """ - - train_op_list = [] - if train_op is not None: - train_op_list.append(train_op) - with ops.control_dependencies(train_op_list): - flush_cache_op_list = [] - for host in range(self._num_hosts): - start_replica = host * 8 - flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu) - flush_cache_op_list.append(flush_op) - with ops.control_dependencies(flush_cache_op_list): - return array_ops.identity(result_tensor) - - def trace_tpu(self, graph, - result_tensor, - train_op, - num_replicas=None, - num_replicas_per_host=None, - num_hosts=None): - """Traces the tensors generated by TPU Ops in a TF graph. - - Args: - graph: the graph of Ops executed on the TPU. - result_tensor: a result tensor of evaluating the graph. - train_op: the training op. - num_replicas: number of replicas used on the TPU. - num_replicas_per_host: number of replicas per TPU host. - num_hosts: total number of TPU hosts. - - Returns: - A tuple (result_tensor_copy, tracing_ops), where: - result_tensor_copy: an exact copy of result_tensor - tracing_ops: a list of tracing ops. If this list - is non empty, the caller of this function - should pose control dependencies upon these - Ops so that they will be executed when the - graph is evaluated. - - Raises: - RuntimeError: If num_replicas_per_host > 8. - """ - - def _cast_unsupported_dtypes(tensor): - """Casts tensor to a supported type.""" - - if tensor.dtype.__eq__(dtypes.int64): - # outside-compilation doesn't support int64 input yet. - return math_ops.cast(tensor, dtypes.int32) - if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( - dtypes.float16): - # Since host can't handle bf16, convert tensor to f32. - return math_ops.cast(tensor, dtypes.float32) - return tensor - - self._device_type = _DEVICE_TYPE_TPU - self._num_replicas = num_replicas - self._num_replicas_per_host = num_replicas_per_host - self._num_hosts = num_hosts - if self._num_replicas_per_host > 8: - # Checks for the assumption in _generate_flush_cache_op(). - raise RuntimeError( - 'num_replicas_per_host (%d) is ' - 'greater than 8'%self._num_replicas_per_host) - - TensorTracer.check_device_type(self._device_type) - result_tensor_copy = self._add_replica_id_to_graph(result_tensor) - fetches = _set_fetches(result_tensor, train_op) - (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph, fetches) - - tracing_ops = [] - for op in operations: - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - tensor_name = out_tensor.name - if tensor_name not in self._traced_tensorname_to_cache_idx_map: - continue - # Create the list of consumers before calling _preprocess_traced_tensor. - # Otherwise, adding control input below, will introduce a cycle in the - # graph. - consumers = out_tensor.consumers() - if not consumers: - continue - processed_out_tensor = self._preprocess_traced_tensor(out_tensor) - processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor) - if self._use_tensor_values_cache(): - cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name] - trace_op = self._save_tensor_value_to_cache_op(graph, - cache_idx, - processed_out_tensor) - else: - trace_op = tpu.outside_compilation( - self._make_tensor_trace_fun(tensor_name), processed_out_tensor) - for consumer_op in consumers: - # pylint: disable=protected-access - consumer_op._add_control_input(trace_op) - # pylint: enable=protected-access - if self._use_tensor_values_cache(): - result_tensor_final = self._flush_tensor_values_cache(graph, - result_tensor_copy, - train_op, - on_tpu=True) - else: - result_tensor_final = result_tensor_copy - self._post_tracing(succeed, sorted_or_cycle) - return (result_tensor_final, tracing_ops) - - def _generate_cpu_result(self, result_tensor, train_op, graph): - """Generates the final CPU result.""" - - if self._use_tensor_values_cache(): - result_tensor_final = self._flush_tensor_values_cache(graph, - result_tensor, - train_op, - on_tpu=False) - else: - result_tensor_final = array_ops.identity(result_tensor) - return result_tensor_final - - def trace_cpu(self, graph, result_tensor, train_op): - """Traces the tensors generated by CPU Ops in a TF graph. - - Args: - graph: the graph of Ops executed on the CPU. - result_tensor: a result tensor of evaluating the graph. - train_op: the training op. - - Returns: - A pair (final_result_tensor, tracing_calls) where: - final_result_tensor: an identical copy of result_tensor. - tracing_calls: a map from keys to trace calls. - A key is constructed from an Op's name. - A trace call consists of a function and a tensor ( - the function will be invoked with the tensor). - """ - - if result_tensor is None: - raise ValueError( - 'The result_tensor passed to trace_cpu should not be None') - - self._device_type = _DEVICE_TYPE_CPU - TensorTracer.check_device_type(self._device_type) - self._num_replicas = 1 - self._num_replicas_per_host = 1 - self._num_hosts = 1 - self._replica_id = 0 - fetches = _set_fetches(result_tensor, train_op) - (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph, fetches) - - tracing_calls = {} - for op in operations: - for i in range(len(op.outputs)): - out_tensor = op.outputs[i] - tensor_name = out_tensor.name - if tensor_name not in self._traced_tensorname_to_cache_idx_map: - continue - # Create the list of consumers before calling _preprocess_traced_tensor. - # Otherwise, adding control input below, will introduce a cycle in the - # graph. - consumers = out_tensor.consumers() - if not consumers: - continue - processed_out_tensor = self._preprocess_traced_tensor(out_tensor) - if self._use_tensor_values_cache(): - cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name] - trace_op = self._save_tensor_value_to_cache_op(graph, - cache_idx, - processed_out_tensor) - for consumer_op in consumers: - # pylint: disable=protected-access - consumer_op._add_control_input(trace_op) - # pylint: enable=protected-access - else: - trace_fun = self._make_tensor_trace_fun(tensor_name) - trace_call = (trace_fun, [processed_out_tensor]) - trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i) - tracing_calls[trace_call_key] = trace_call - - self._post_tracing(succeed, sorted_or_cycle) - final_result_tensor = self._generate_cpu_result(result_tensor, - train_op, - graph) - return (final_result_tensor, tracing_calls) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tensor_tracer import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index 6ae718cc2c9716587849aeee8abcd0a1de82a9ae..5bf805752cf51b0a0f4b7400b18b63aae93cf831 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -1,220 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== -"""Defines the `Topology` class, that describes a TPU fabric topology.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.proto import topology_pb2 - - -def _tpu_device_name(job, task, device): - """Returns the device name for the TPU `device` on `task` of `job`.""" - if job is None: - return "/task:%d/device:TPU:%d" % (task, device) - else: - return "/job:%s/task:%d/device:TPU:%d" % (job, task, device) - - -def _tpu_host_device_name(job, task): - """Returns the device name for the CPU device on `task` of `job`.""" - if job is None: - return "/task:%d/device:CPU:0" % task - else: - return "/job:%s/task:%d/device:CPU:0" % (job, task) - - -class Topology(object): - """Describes a set of TPU devices. - - Represents both the shape of the physical mesh, and the mapping between - TensorFlow TPU devices to physical mesh coordinates. - """ - - def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None): - """Builds a Topology object. - - If `serialized` is not `None`, the topology is parsed from `serialized` and - the other arguments are ignored. Otherwise, the topology is computed from - `mesh_shape` and `device_coordinates`. - - Args: - serialized: A serialized `TopologyProto`, or `None`. If not `None`, the - serialized proto is parsed to discover the topology. - mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`, - the shape of the TPU topology, in number of cores. Ignored if - `serialized` is not `None`. - device_coordinates: A rank 3 numpy array that describes the mapping from - TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored - if `serialized is not `None`. - - Raises: - ValueError: If `serialized` does not describe a well-formed topology. - ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence - of 3 positive integers. - ValueError: If `serialized` is `None` and `device_coordinates` is not a - rank 3 numpy int32 array that describes a valid coordinate mapping. - """ - - self._serialized = serialized - - if serialized: - self._parse_topology(serialized) - else: - self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) - self._device_coordinates = np.asarray(device_coordinates, np.int32) - if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): - raise ValueError("`mesh_shape` must be a sequence of 3 positive " - "entries; got {}".format(self._mesh_shape)) - - if (len(self._device_coordinates.shape) != 3 or - self._device_coordinates.shape[2] != len(self._mesh_shape)): - raise ValueError("`device_coordinates` must be a rank 3 int32 array " - "with minor dimension equal to the mesh shape rank") - - self._topology_tasks, self._topology_devices = self._invert_topology() - - def _parse_topology(self, serialized): - """Parses a serialized `TopologyProto` into `self`.""" - proto = topology_pb2.TopologyProto() - proto.ParseFromString(serialized) - - self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32) - if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): - raise ValueError("`mesh_shape` must be a vector of size 3 with positive " - "entries; got {}".format(self._mesh_shape)) - - if proto.num_tasks < 0: - raise ValueError("`num_tasks` must be >= 0; got {}".format( - proto.num_tasks)) - if proto.num_tpu_devices_per_task < 0: - raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format( - proto.num_tpu_devices_per_task)) - - expected_coordinates_size = ( - proto.num_tasks * proto.num_tpu_devices_per_task * len( - proto.mesh_shape)) - if len(proto.device_coordinates) != expected_coordinates_size: - raise ValueError("`device_coordinates` must have shape num_tasks ({}) * " - "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); " - "got shape {}".format(proto.num_tasks, - proto.num_tpu_devices_per_task, - proto.mesh_shape, - len(proto.device_coordinates))) - - coords = np.array(proto.device_coordinates, dtype=np.int32) - if any(coords < 0): - raise ValueError("`device_coordinates` must be >= 0") - coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task, - len(proto.mesh_shape))) - self._device_coordinates = coords - - def _invert_topology(self): - """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.""" - tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32) - devices = np.full(list(self.mesh_shape), -1, dtype=np.int32) - for task in xrange(self.device_coordinates.shape[0]): - for device in xrange(self.device_coordinates.shape[1]): - x, y, z = self.device_coordinates[task, device, :] - tasks[x, y, z] = task - devices[x, y, z] = device - return tasks, devices - - @property - def mesh_shape(self): - """A rank 1 int32 array describing the shape of the TPU topology.""" - return self._mesh_shape - - @property - def mesh_rank(self): - """Returns the number of dimensions in the mesh.""" - return len(self._mesh_shape) - - @property - def device_coordinates(self): - """Describes the mapping from TPU devices to topology coordinates. - - Returns: - A rank 3 int32 array with shape `[tasks, devices, axis]`. - `tasks` is the number of tasks in the TPU cluster, `devices` is the number - of TPU devices per task, and `axis` is the number of axes in the TPU - cluster topology. Each entry gives the `axis`-th coordinate in the - topology of a task/device pair. TPU topologies are 3-dimensional, with - dimensions `(x, y, core number)`. - """ - return self._device_coordinates - - def task_ordinal_at_coordinates(self, device_coordinates): - """Returns the TensorFlow task number attached to `device_coordinates`. - - Args: - device_coordinates: An integer sequence describing a device's physical - coordinates in the TPU fabric. - - Returns: - Returns the TensorFlow task number that contains the TPU device with those - physical coordinates. - """ - return self._topology_tasks[tuple(device_coordinates)] - - def tpu_device_ordinal_at_coordinates(self, device_coordinates): - """Returns the TensorFlow device number at `device_coordinates`. - - Args: - device_coordinates: An integer sequence describing a device's physical - coordinates in the TPU fabric. - - Returns: - Returns the TensorFlow device number within the task corresponding to - attached to the device with those physical coordinates. - """ - return self._topology_devices[tuple(device_coordinates)] - - def cpu_device_name_at_coordinates(self, device_coordinates, job=None): - """Returns the CPU device attached to a logical core.""" - return _tpu_host_device_name( - job, self._topology_tasks[tuple(device_coordinates)]) - - def tpu_device_name_at_coordinates(self, device_coordinates, job=None): - """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, - self._topology_tasks[tuple(device_coordinates)], - self._topology_devices[tuple(device_coordinates)]) - - @property - def num_tasks(self): - """Returns the number of TensorFlow tasks in the TPU slice.""" - return self._device_coordinates.shape[0] - - @property - def num_tpus_per_task(self): - """Returns the number of TPU devices per task in the TPU slice.""" - return self._device_coordinates.shape[1] - - def serialized(self): - """Returns the serialized form of the topology.""" - if self._serialized is None: - proto = topology_pb2.TopologyProto() - proto.mesh_shape[:] = list(self._mesh_shape) - proto.num_tasks = self._device_coordinates.shape[0] - proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates.extend(list(self._device_coordinates.flatten())) - self._serialized = proto.SerializeToString() - - return self._serialized +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.topology import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index de2bfd49eca50c87dc506d9aa690d49c8da20460..5364b20f231ac7af8adf943c3d5e21921b7a06a9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -1,1566 +1,25 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ====================================== - -"""Library of TPU helper functions.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.compiler import xla -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.proto import dynamic_padding_pb2 as dynamic_padding -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function - -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.compat import compat as api_compat -from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import compat -from tensorflow.python.util import nest - - -# Operations that indicate some error in the users graph, e.g. a placeholder -# that's introduced outside of the infeed. -_BLACKLISTED_OPS = set([ - "Placeholder", -]) - -# XLA doesn't currently support reading of intermediate tensors, thus some ops -# are not supported. -_UNSUPPORTED_OPS = set([ - "AudioSummary", - "AudioSummaryV2", - "HistogramSummary", - "ImageSummary", - "MergeSummary", - "Print", - "ScalarSummary", - "TensorSummary", - "TensorSummaryV2", - ]) - -_MAX_WARNING_LINES = 5 - -_TPU_REPLICATE_ATTR = "_tpu_replicate" -_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status" -_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" - - -def _tpu_system_device_name(job): - """Returns the device name for the TPU_SYSTEM device of `job`.""" - if job is None: - return "/device:TPU_SYSTEM:0" - else: - return "/job:%s/device:TPU_SYSTEM:0" % job - - -def initialize_system(embedding_config=None, job=None): - """Initializes a distributed TPU system for use with TensorFlow. - - Args: - embedding_config: If not None, a `TPUEmbeddingConfiguration` proto - describing the desired configuration of the hardware embedding lookup - tables. If embedding_config is None, no hardware embeddings can be used. - job: The job (the XXX in TensorFlow device specification /job:XXX) that - contains the TPU devices that will be initialized. If job=None it is - assumed there is only one job in the TensorFlow flock, and an error will - be returned if this assumption does not hold. - Returns: - A serialized `TopologyProto` that describes the TPU system. Note: - the topology must be evaluated using `Session.run` before it can be used. - """ - config_string = ("" if embedding_config is None else - embedding_config.SerializeToString()) - with ops.device(_tpu_system_device_name(job)): - return tpu_ops.configure_distributed_tpu(embedding_config=config_string) - - -def shutdown_system(job=None): - """Shuts down a running a distributed TPU system.""" - with ops.device(_tpu_system_device_name(job)): - shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() - return shutdown_distributed_tpu - - -def core(num): - """Returns the device name for a core in a replicated TPU computation. - - Args: - num: the virtual core number within each replica to which operators should - be assigned. - Returns: - A device name, suitable for passing to `tf.device()`. - """ - return "device:TPU_REPLICATED_CORE:{}".format(num) - - -class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside a TPU computation. - - The primary role of `TPUReplicateContext` is to mark operators inside a - tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ - is a unique name. - - We use a `ControlFlowContext` to perform the annotation since it integrates - with Tensorflow constructs like ResourceVariables. For example, if a - `ResourceVariable` is constructed inside a tpu.replicate() block, the - `ResourceVariable` implementation can use - `with ops.control_dependencies(None)` to build the variable's definition - outside the replicated computation. - """ - - def __init__(self, name, num_replicas, pivot): - """Builds a new TPUReplicateContext. - - Args: - name: a unique name for the context, used to populate the `_tpu_replicate` - attribute. - num_replicas: an integer that gives the number of replicas for the - computation. - pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any - inputs will have a control dependency on the pivot node. This ensures - that nodes are correctly included in any enclosing control flow - contexts. - """ - super(TPUReplicateContext, self).__init__() - self._num_replicas = num_replicas - self._outer_device_function_stack = None - self._oc_dev_fn_stack = None - self._outside_compilation_cluster = None - self._outside_compilation_counter = 0 - self._in_gradient_colocation = None - self._gradient_colocation_stack = [] - self._host_compute_core = [] - self._name = name - self._name_as_bytes = compat.as_bytes(name) - self._unsupported_ops = [] - self._pivot = pivot - self._replicated_vars = {} - - def get_replicated_var_handle(self, name, vars_): - """Returns a variable handle for replicated TPU variable 'var'. - - This is a method used by an experimental replicated variable implementation - and is not intended as a public API. - - Args: - name: The common name of the variable. - vars_: The replicated TPU variables. - - Returns: - The handle of the TPU replicated input node. - """ - handle = self._replicated_vars.get(name) - if handle is not None: - return handle - - # Builds a TPUReplicatedInput node for the variable, if one does not already - # exist. The TPUReplicatedInput node must belong to the enclosing - # control-flow scope of the TPUReplicateContext. - # TODO(phawkins): consider changing the contract of the TPU encapsulation - # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope - # instead. - - # pylint: disable=protected-access - graph = ops.get_default_graph() - saved_context = graph._get_control_flow_context() - graph._set_control_flow_context(self.outer_context) - handle = tpu_ops.tpu_replicated_input( - [v.handle for v in vars_], name=name + "/handle") - graph._set_control_flow_context(saved_context) - # pylint: enable=protected-access - self._replicated_vars[name] = handle - return handle - - def report_unsupported_operations(self): - if self._unsupported_ops: - op_str = "\n".join([" %s (%s)" % (op.type, op.name) - for op in self._unsupported_ops[:_MAX_WARNING_LINES]]) - logging.warning("%d unsupported operations found: \n%s", - len(self._unsupported_ops), op_str) - if len(self._unsupported_ops) > _MAX_WARNING_LINES: - logging.warning("... and %d more" % - (len(self._unsupported_ops) - _MAX_WARNING_LINES)) - - def EnterGradientColocation(self, op, gradient_uid): - if op is not None: - self._gradient_colocation_stack.append(op) - if not self._outside_compilation_cluster: - try: - outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR) - if self._in_gradient_colocation: - raise NotImplementedError( - "Cannot nest gradient colocation operations outside compilation" - ) - if gradient_uid == "__unsupported__": - raise NotImplementedError( - "No gradient_uid calling gradient within outside_compilation") - # When we take the gradient of an op X in an outside_compilation - # cluster C in a forward computation we would like to put the ops - # corresponding to the gradient of X into a new outside_compilation - # cluster C'. However, if we take the gradient of X twice, the second - # one should get yet another new outside_compilation cluster C''. - # - # The mechanism we adopt is to use a 'root_cluster' which is the - # cluster that X was in before we took gradients, and a 'gradient_uid' - # which is different for every invocation of gradients, and put the - # gradient of X in cluster 'root_cluster.gradient_uid'. - # - # When taking a gradient of a gradient, some ops will be colocated - # with Op in the forward pass (e.g., cluster root_cluster) and some in - # the backward pass (e.g., cluster root_cluster.initial_gradient_uid). - # We need all of the grad-of-grad ops to be in the same cluster to - # avoid cyclic dependencies between clusters. We adopt a heuristic - # that puts any op clustered with root_cluster. in - # root_cluster.gradient_uid, even if xxx was initial_gradient_uid. - self._in_gradient_colocation = op - parts = outside_attr.split(".") - cluster = parts[0] + "." + gradient_uid - self._EnterOutsideCompilationScope(cluster=cluster) - except ValueError: - # The attr was not present: do nothing. - pass - - def ExitGradientColocation(self, op, gradient_uid): - if op is not None: - if not self._gradient_colocation_stack: - raise errors.InternalError( - op.node_def, op, - "Badly nested gradient colocation: empty stack when popping Op " + - op.name) - last_op = self._gradient_colocation_stack.pop() - if op is last_op: - if op is self._in_gradient_colocation: - self._in_gradient_colocation = None - self._ExitOutsideCompilationScope() - else: - raise errors.InternalError( - op.node_def, op, "Badly nested gradient colocation, expected " + - last_op + ", got " + op.name) - - def _EnterOutsideCompilationScope(self, cluster=None): - - class FakeOp(object): - """A helper class to determine the current device. - - Supports only the type and device set/get methods needed to run the - graph's _apply_device_function method. - """ - - def __init__(self): - self._device = "" - - @property - def type(self): - return "FakeOp" - - @property - def device(self): - return self._device - - def _set_device(self, device): - if isinstance(device, pydev.DeviceSpec): - self._device = device.to_string() - else: - self._device = device - - if self._outside_compilation_cluster: - raise NotImplementedError("Cannot nest outside_compilation clusters") - if cluster: - self._outside_compilation_cluster = cluster - else: - self._outside_compilation_cluster = str(self._outside_compilation_counter) - self._outside_compilation_counter += 1 - graph = ops.get_default_graph() - fake_op = FakeOp() - graph._apply_device_functions(fake_op) # pylint: disable=protected-access - device = pydev.DeviceSpec.from_string(fake_op.device) - if (device.device_type == "TPU_REPLICATED_CORE" and - device.device_index is not None): - self._host_compute_core.append(self._outside_compilation_cluster + ":" + - str(device.device_index)) - self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access - graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access - - def _ExitOutsideCompilationScope(self): - if not self._outside_compilation_cluster: - raise NotImplementedError( - "Attempted to exit outside_compilation scope when not in scope") - self._outside_compilation_cluster = None - graph = ops.get_default_graph() - graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access - - def Enter(self): - if not self._outer_device_function_stack: - # Capture the device function stack at the time of first entry - # since that is the stack that will be used outside_compilation. - graph = ops.get_default_graph() - # pylint: disable=protected-access - self._outer_device_function_stack = graph._device_function_stack.copy() - # pylint: enable=protected-access - super(TPUReplicateContext, self).Enter() - - def HostComputeCore(self): - return self._host_compute_core - - def _RemoveExternalControlEdges(self, op): - """Remove any external control dependency on this op.""" - internal_control_inputs = [] - external_control_inputs = [] - for x in op.control_inputs: - # pylint: disable=protected-access - is_internal_op = False - ctxt = x._get_control_flow_context() - while ctxt is not None: - if ctxt == self: - is_internal_op = True - break - ctxt = ctxt._outer_context - if is_internal_op: - internal_control_inputs.append(x) - else: - external_control_inputs.append(x) - # pylint: enable=protected-access - # pylint: disable=protected-access - op._remove_all_control_inputs() - op._add_control_inputs(internal_control_inputs) - # pylint: enable=protected-access - return internal_control_inputs, external_control_inputs - - def AddOp(self, op): - # pylint: disable=protected-access - if op.type in _BLACKLISTED_OPS: - logging.error("Operation of type %s (%s) is not supported on the TPU. " - "Execution will fail if this op is used in the graph. " % - (op.type, op.name)) - - if op.type in _UNSUPPORTED_OPS: - self._unsupported_ops.append(op) - - if any(x.dtype._is_ref_dtype for x in op.inputs): - raise NotImplementedError( - "Non-resource Variables are not supported inside TPU computations " - "(operator name: %s)" % op.name) - if _TPU_REPLICATE_ATTR in op.node_def.attr: - raise ValueError("TPU computations cannot be nested") - op._set_attr(_TPU_REPLICATE_ATTR, - attr_value_pb2.AttrValue(s=self._name_as_bytes)) - if self._outside_compilation_cluster: - op._set_attr( - _OUTSIDE_COMPILATION_ATTR, - attr_value_pb2.AttrValue( - s=compat.as_bytes(self._outside_compilation_cluster))) - if self._num_replicas > 1 or not self._outside_compilation_cluster: - # Prevent feeding or fetching anything that is being compiled, - # and any replicated outside_compilation Op. - op.graph.prevent_feeding(op) - op.graph.prevent_fetching(op) - - # Remove any control edges from outer control flow contexts. These may cause - # mismatched frame errors. - (internal_control_inputs, - external_control_inputs) = self._RemoveExternalControlEdges(op) - - if not op.inputs: - # Add a control edge from the control pivot to this op. - if not internal_control_inputs: - # pylint: disable=protected-access - op._add_control_input(self.GetControlPivot()) - # pylint: enable=protected-access - else: - for index in xrange(len(op.inputs)): - x = op.inputs[index] - real_x = self.AddValue(x) - if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - - if external_control_inputs: - # Use an identity to pull control inputs as data inputs. Note that we - # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() - # pylint: disable=protected-access - op._add_control_inputs(external_control_inputs) - # pylint: enable=protected-access - - # Mark op's outputs as seen by this context and any outer contexts. - output_names = [x.name for x in op.outputs] - context = self - while context is not None: - # pylint: disable=protected-access - context._values.update(output_names) - context = context._outer_context - # pylint: enable=protected-access - - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - """Add `val` to the current context and its outer context recursively.""" - if val.name in self._values: - # Use the real value if it comes from outer context. - result = self._external_values.get(val.name) - return val if result is None else result - - result = val - self._values.add(val.name) - if self._outer_context: - result = self._outer_context.AddValue(val) - self._values.add(result.name) - - self._external_values[val.name] = result - - return result - - def AddInnerOp(self, op): - self.AddOp(op) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - @property - def grad_state(self): - # Define the gradient loop state associated with the TPUReplicateContext to - # be None as the TPUReplicateContext does not get nested nor does the - # grad_state outside the TPUReplicateContext affect the graph inside so the - # grad_state should be as if this is the top-level gradient state. - return None - - @property - def back_prop(self): - """Forwards to the enclosing while context, if any.""" - if self.GetWhileContext(): - return self.GetWhileContext().back_prop - return False - - def GetControlPivot(self): - return self._pivot - - -def outside_compilation(computation, *args, **kwargs): - """Builds part of a computation outside any current TPU replicate scope. - - Args: - computation: A Python function that builds the computation to - place on the host. - *args: the positional arguments for the computation. - **kwargs: the keyword arguments for the computation. - - Returns: - The Tensors returned by computation. - """ - args = [] if args is None else args - graph = ops.get_default_graph() - - # If we are in a TPUReplicateContext, signal that we are now - # outside_compilation - initial_context = graph._get_control_flow_context() # pylint: disable=protected-access - context = initial_context - while context: - if isinstance(context, TPUReplicateContext): - context._EnterOutsideCompilationScope() # pylint: disable=protected-access - context = context.outer_context - - retval = computation(*args, **kwargs) - - # If we are in a TPUReplicateContext, signal that we are no longer - # outside_compilation - final_context = graph._get_control_flow_context() # pylint: disable=protected-access - if initial_context is not final_context: - raise NotImplementedError( - "Control-flow context cannot be different at start and end of an " - "outside_compilation scope") - context = initial_context - while context: - if isinstance(context, TPUReplicateContext): - context._ExitOutsideCompilationScope() # pylint: disable=protected-access - context = context.outer_context - - return retval - - -def replicate(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None, - maximum_shapes=None): - """Builds a graph operator that runs a replicated TPU computation. - - Args: - computation: A Python function that builds the computation to replicate. - inputs: A list of lists of input tensors or `None` (equivalent to - `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. Each input can be a nested structure - containing values that are convertible to tensors. Note that passing an - N-dimension list of compatible values will result in a N-dimention list of - scalar tensors rather than a single Rank-N tensors. If you need different - behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to computation. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each replica of the computation uses - only one core, and there is either only one replica, or the number of - replicas is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - maximum_shapes: A nested structure of tf.TensorShape representing the shape - to which the respective component of each input element in each replica - should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a - tf.TensorShape or -1 in a tensor-like object) will be padded to the - maximum size of that dimension over all replicas. Note that if the input - dimension is already static, we won't do padding on it and we require the - maximum_shapes to have the same value or None on that dimension. The - structure of `maximum_shapes` needs to be the same as `inputs[0]`. - Returns: - A list of outputs, indexed by `[replica_num]` each output can be a nested - structure same as what computation() returns with a few exceptions. - - Exceptions include: - 1) None output: a NoOp would be returned which control-depends on - computation. - 2) Single value output: A tuple containing the value would be returned. - 3) Operation-only outputs: a NoOp would be returned which - control-depends on computation. - TODO(b/121383831): Investigate into removing these special cases. - - Raises: - ValueError: If all replicas do not have equal numbers of input tensors. - ValueError: If the number of inputs per replica does not match - the number of formal parameters to `computation`. - ValueError: If the static `inputs` dimensions don't match with the values - given in `maximum_shapes`. - ValueError: If the structure of inputs per replica does not match - the structure of `maximum_shapes`. - """ - return split_compile_and_replicate( - computation, - inputs, - infeed_queue, - device_assignment, - name, - maximum_shapes=maximum_shapes)[1] - - -def _pad_all_input(inputs, padded_shapes): - """Pad all input tensors given padded_shapes. - - The real shape tensors will be concatenated with the padded original inputs. - - Args: - inputs: The original inputs. - padded_shapes: A list of padded shapes for each input. - - Returns: - The padded inputs and a PaddingMap list which maps the padded input - dimension to the real shape argument index. - """ - input_shape_tensors = [] - for core_idx, inputs_per_core in enumerate(inputs): - for idx, input_tensor in enumerate(inputs_per_core): - if core_idx == 0: - input_shape_tensors.append([]) - input_shape_tensors[idx].append(array_ops.shape(input_tensor)) - - maximum_shapes = [] - for shapes_per_input in input_shape_tensors: - maximum_shapes.append( - math_ops.reduce_max(array_ops.stack(shapes_per_input), axis=0)) - - padded_inputs = [] - real_shapes = [] - padding_maps = [] - for core_idx, inputs_per_core in enumerate(inputs): - padded_inputs.append([]) - real_shapes.append([]) - real_shape_idx = len(inputs_per_core) - 1 - for idx, input_tensor in enumerate(inputs_per_core): - input_shape_tensor = input_shape_tensors[idx][core_idx] - input_shape = input_tensor.get_shape() - padded_shape = padded_shapes[idx] - - # The static shape of inputs should be compatible with the given padded - # shapes. - input_shape.assert_is_compatible_with(padded_shape) - - if input_shape.is_fully_defined(): - # Do nothing if the shape of the whole tensor is already static. - padded_inputs[core_idx].append(input_tensor) - else: - # Only pad the non static shape dimension. - for i, s in enumerate(input_shape): - if s.value is None: - if core_idx == 0: - real_shape_idx += 1 - padding_map = dynamic_padding.PaddingMap() - padding_map.arg_index = idx - padding_map.shape_index = i - padding_map.padding_arg_index = real_shape_idx - padding_maps.append(padding_map) - real_shapes[core_idx].append( - math_ops.cast(input_shape_tensor[i], dtypes.uint32)) - - paddings = [] - for i, s in enumerate(padded_shape): - if input_shape[i].value: - # Don't pad if input shape is already static. - padding = [0, 0] - else: - if s.value: - # Pad to the given maximum value. - padding = [0, s.value - input_shape_tensor[i]] - else: - # If maximum value is not given, then pad to the maximum dimension - # among all the cores. - padding = [0, maximum_shapes[idx][i] - input_shape_tensor[i]] - paddings.append(padding) - - padded_input = array_ops.pad(input_tensor, paddings) - padded_inputs[core_idx].append(padded_input) - - num_replicas = len(padded_inputs) - for i in range(num_replicas): - padded_inputs[i].extend(real_shapes[i]) - - return padded_inputs, padding_maps - - -def split_compile_and_replicate(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None, - use_tpu=True, - maximum_shapes=None): - """Builds graph operators that runs compilation and replicated computation. - - This is a lower level interface than replicate that returns a separate compile - and execute output tensor. In the generated graph the compile op feeds into - the execute op and no additional compilation is incurred when running the - compile op before the execute op. The compile op returns additional - information about the compilation but does not return the compiled program. - - Args: - computation: A Python function that builds the computation to replicate. - inputs: A list of lists of input tensors or `None` (equivalent to - `[[]]`), indexed by `[replica_num][input_num]`. All replicas must - have the same number of inputs. Each input can be a nested structure - containing values that are convertible to tensors. Note that passing an - N-dimension list of compatible values will result in a N-dimention list of - scalar tensors rather than a single Rank-N tensors. If you need different - behavior, convert part of inputs to tensors with `tf.convert_to_tensor`. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to computation. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each replica of the computation uses - only one core, and there is either only one replica, or the number of - replicas is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU - backends. Currently, only supports a default placement (computation is - placed on GPU if one is available, and on CPU if not). - maximum_shapes: A nested structure of tf.TensorShape representing the shape - to which the respective component of each input element in each replica - should be padded. Any unknown dimensions (e.g. tf.Dimension(None) in a - tf.TensorShape or -1 in a tensor-like object) will be padded to the - maximum size of that dimension over all replicas. Note that if the input - dimension is already static, we won't do padding on it and we require the - maximum_shapes to have the same value or None on that dimension. The - structure of `maximum_shapes` needs to be the same as `inputs[0]`. - - Returns: - A list of lists with the first list corresponding to the compile op and the - second a list of output tensors, indexed by `[replica_num][output_num]`. - Raises: - ValueError: If all replicas do not have equal numbers of input tensors. - ValueError: If the number of inputs per replica does not match - the number of formal parameters to `computation`. - ValueError: If the static `inputs` dimensions don't match with the values - given in `maximum_shapes`. - ValueError: If the structure of inputs per replica does not match - the structure of `maximum_shapes`. - """ - del name - inputs = [[]] if inputs is None else inputs - - metadata_kwargs = {} - if device_assignment is not None: - # Turn the Numpy array into a flattened list so we can pass it as an - # operator attribute. - metadata_kwargs = { - "topology": - device_assignment.topology.serialized(), - "device_assignment": - device_assignment.core_assignment.flatten().tolist() - } - # TODO(phawkins): remove this case after the forward compatibility window - # expires on 2018-10-5. - if api_compat.forward_compatible(2018, 10, 5): - metadata_kwargs["num_cores_per_replica"] = ( - device_assignment.num_cores_per_replica) - else: - metadata_kwargs["computation_shape"] = [ - device_assignment.num_cores_per_replica - ] - - if ((not isinstance(inputs, list)) or - any(not isinstance(inp, (list, tuple)) for inp in inputs)): - raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") - - num_replicas = len(inputs) - - # No replicas? Nothing to do. - if num_replicas == 0: - return [] - - # Checks all replicas have the same structure. - for i in xrange(1, num_replicas): - nest.assert_same_structure(inputs[0], inputs[i]) - - # Flatten inputs. - flat_inputs = [ - nest.flatten(per_replica_input) for per_replica_input in inputs - ] - # Converts inputs to Tensors. - flat_inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in flat_inputs] - - # Verifies that all replicas have matching numbers and types of inputs - flat_input_types = [x.dtype for x in flat_inputs[0]] - input_arity = len(inputs[0]) - flat_input_arity = len(flat_input_types) - for i in range(num_replicas): - if len(inputs[i]) != input_arity: - raise ValueError("Replicas must have the same number of inputs. " - "Replica 0 had {} inputs, replica {} had {} " - "inputs.".format(input_arity, i, len(inputs[i]))) - - types = [x.dtype for x in flat_inputs[i]] - if types != flat_input_types: - raise ValueError("Replicas must have matching input types. Replica 0 had " - "input types {}, replica {} had input types {}".format( - flat_input_types, i, types)) - - arg_error = xla.check_function_argument_count( - computation, input_arity, infeed_queue) - if arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied computation cannot be called with the specified inputs. " - "You specified %d inputs: %s, but the computation needs %s" % ( - input_arity, str([i.name for i in inputs[0]]), arg_error)) - else: - raise TypeError( - "Supplied computation cannot be called with the specified inputs. " - "You specified %d inputs: %s and %d additional inputs from infeed," - " but the computation needs %s" % (input_arity, str( - [i.name - for i in inputs[0]]), infeed_queue.number_of_tuple_elements, - arg_error)) - - if maximum_shapes: - if infeed_queue: - raise ValueError( - "Dynamic input shapes are not supported with infeed queues") - - # Make sure maximum_shapes has the same structure as inputs. - nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False) - - # Flatten padded shapes. - flat_maximum_shapes = nest.flatten(maximum_shapes) - flat_maximum_shapes = [ - tensor_shape.TensorShape(s) for s in flat_maximum_shapes - ] - - flat_inputs, padding_maps = _pad_all_input(flat_inputs, flat_maximum_shapes) - - serialized_padding_maps = [] - for padding_map in padding_maps: - serialized_padding_maps.append(padding_map.SerializeToString()) - metadata_kwargs["padding_map"] = serialized_padding_maps - - graph = ops.get_default_graph() - - # Fan-in: Builds a TPUReplicatedInput node for each input. - flat_replicated_inputs = [] - for i in range(0, len(flat_inputs[0])): - replicas = [flat_inputs[replica][i] for replica in xrange(num_replicas)] - flat_replicated_inputs.append( - tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - - cluster_name = graph.unique_name("cluster") - pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") - context = TPUReplicateContext( - name=cluster_name, num_replicas=num_replicas, pivot=pivot) - try: - context.Enter() - - metadata = tpu_ops.tpu_replicate_metadata( - num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs) - - with tpu_function.tpu_shard_context( - num_replicas), ops.control_dependencies([metadata]): - - # Add identity ops so even unused inputs are "consumed" by the - # computation. This is to avoid orphaned TPUReplicatedInput nodes. - # TODO(phawkins): consider instead pruning unused TPUReplicatedInput - # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. - flat_replicated_inputs = [ - array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(flat_replicated_inputs) - ] - for i in flat_replicated_inputs: - # pylint: disable=protected-access - # Add an attribute to the identity node so that they could be removed in - # encapsulate TPU computation pass if unused. However we don't remove - # inputs when dynamic padding is enabled. - # TODO(rxsang): Use other ways except argument index in padding_map so - # outside compilation can work with dynamic padding correctly. - if maximum_shapes is None: - i.op._set_attr("_tpu_input_identity", - attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - - # Unflatten the computation inputs to match original input structure. - computation_inputs = nest.pack_sequence_as( - structure=inputs[0], - flat_sequence=flat_replicated_inputs[:flat_input_arity]) - - # If there is an infeed queue, adds the dequeued values to the - # computation's inputs. - if infeed_queue is not None: - infeed_queue.set_number_of_shards(num_replicas) - for t in infeed_queue.generate_dequeue_op(): - computation_inputs.append(t) - - # Only resource variables work inside a TPU computation, so turn on - # resource variables for the computation. - # TODO(phawkins): consider removing this code. It will - # be less confusing to clients if they knowingly choose to use resource - # variables. - # Partitioned variables is not supported (b/112311320). - vscope = variable_scope.get_variable_scope() - saved_use_resource = vscope.use_resource - saved_custom_getter = vscope.custom_getter - - def custom_getter(getter, name, *args, **kwargs): - """Variables on TPU have a few restrictions.""" - partitioner = kwargs["partitioner"] - if partitioner is not None: - kwargs["partitioner"] = None - logging.warning( - "Partitioned variables are not supported on TPU. Got " - "`partitioner` that is {} for variable {}. " - "Setting `partitioner` to `None`." - .format(partitioner, name)) - if saved_custom_getter is None: - return getter(name, *args, **kwargs) - else: - return saved_custom_getter(getter, name, *args, **kwargs) - - vscope.set_use_resource(True) - vscope.set_custom_getter(custom_getter) - - outputs = computation(*computation_inputs) - - vscope.set_use_resource(saved_use_resource) - vscope.set_custom_getter(saved_custom_getter) - - outputs_is_flat = xla.is_flat(outputs) - if outputs_is_flat: - output_tensors, control_deps = _postprocess_flat_outputs(outputs) - else: - output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) - - context.ExitResult(output_tensors) - finally: - context.report_unsupported_operations() - context.Exit() - host_compute_core = context.HostComputeCore() - - if host_compute_core: - attr_value = attr_value_pb2.AttrValue() - attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) - metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access - - with ops.control_dependencies([metadata]): - if use_tpu: - compile_status = tpu_ops.tpu_compilation_result() - op = compile_status.op - attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)) - op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access - else: - compile_status = control_flow_ops.no_op(name="compilation_status") - - if not output_tensors: - # Returns a list of NoOps dependent on the replication Op, indexed by - # [replica_num]. - return [ - compile_status, - [ - control_flow_ops.group(control_deps, name="shard_%d" % i) - for i in range(num_replicas) - ] - ] - - # Fan-out: Builds a TPUReplicatedOutput node for each output. - replicated_outputs = [[] for i in xrange(num_replicas)] - for i, t in enumerate(output_tensors): - # Fan-out: Builds a TPUReplicatedOutput node for each output. - ys = tpu_ops.tpu_replicated_output( - t, num_replicas, name="output{}".format(i)) - - # Wraps the outputs in identity operators so the names of any possible - # `fetch` nodes are preserved by the replication rewrite. - with ops.control_dependencies(control_deps): - for replica in xrange(num_replicas): - replicated_outputs[replica].append( - array_ops.identity( - ys[replica], name="output_%d_shard_%d" % (i, replica))) - - if not outputs_is_flat: - replicated_outputs = [ - nest.pack_sequence_as(outputs, replica_outs) - for replica_outs in replicated_outputs - ] - - return [compile_status, replicated_outputs] - - -def _postprocess_flat_outputs(outputs): - """Validates non-flat outputs, add backs device assignments and other attrs. - - Args: - outputs: Output from `computation` inside `tpu.rewrite`. - - Returns: - Tensors and Operations extracted from outputs. - """ - # Following code segment is to preserve legacy behavior. Previously we only - # supported flat outputs and thus for consistency it was nice to convert even - # single element into a tuple. But now that we support arbitrary output - # structure, this is no longer necessary. - # TODO(b/121383831): Migrate all legacy use cases and delete this special - # case. - # If the computation returns `None`, make it an empty tuple. - if outputs is None: - outputs = tuple() - # If the computation only returned one value, makes it a tuple. - if not isinstance(outputs, collections.Sequence): - outputs = (outputs,) - - # Append `no_op` here so that fetching any return value of this function - # will trigger TPUExecute node. - outputs += (control_flow_ops.no_op(),) - try: - with ops.device(core(0)): - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - "convertible to Tensors. Got '%s'" % str(e)) - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU functions must return zero-or more Tensor values followed by " - "zero or more Operations.") - - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - new_output_tensors = [] - for t in output_tensors: - with ops.device(t.device if t.device else core(0)): - o = array_ops.identity(t) - # pylint: disable=protected-access - o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - new_output_tensors.append(o) - return new_output_tensors, output_operations - - -def _postprocess_non_flat_outputs(outputs): - """Validates non-flat outputs, add backs device assignments and other attrs. - - Args: - outputs: Output from `computation` inside `tpu.rewrite`. - - Returns: - Tensors extracted from outputs and an empty list because Operations are not - allowed in non-flat outputs.. - """ - - # Flatten output items. - flat_outputs = nest.flatten(outputs) - - # Convert all non-Operation outputs to Tensors. - for i, o in enumerate(flat_outputs): - if isinstance(o, ops.Operation): - raise ValueError( - "tpu.rewrite does not support Operation as return value in non-flat " - "output structure. You can set returned Operations as control " - "dependencies of returned Tensors so Operations are triggered when " - 'Tensors are evaluated. Operation found: "%s"' % o.name) - - try: - o = ops.convert_to_tensor(o) - except Exception as e: - raise ValueError( - "TPU function return values must all either be Operations or " - 'convertible to Tensors. Got error: "%s"' % str(e)) - - # Wraps outputs in Identity ops. Otherwise a replicated input copied - # straight to an output would bypass the replicate(). This would be bad - # because the TPUReplicatedInput/TPUReplicatedOutput operator would not - # be rewritten away, leading to a runtime error. - # TODO(phawkins): extend the rewrite to elide these nodes instead. - with ops.device(core(0)): - o = array_ops.identity(o) - # pylint: disable=protected-access - o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True)) - # pylint: enable=protected-access - flat_outputs[i] = array_ops.identity(o) - - # All flat_outputs are Tensors, and no Operations. - return flat_outputs, [] - - -def split_compile_and_shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` for parallel execution. - - `inputs` must be a list of Tensors or None (equivalent to an empty list), each - of which has a corresponding split axis (from `input_shard_axes`). Each input - is split into `num_shards` pieces along the corresponding axis, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - If `outputs_from_all_shards` is true, the outputs from all shards of - `computation` are concatenated back together along their `output_shards_axes`. - Otherwise, each output is taken from an arbitrary shard. - - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). Each - input tensor has a corresponding shard axes, given by `input_shard_axes`, - which must have size divisible by `num_shards`. - num_shards: The number of shards. - input_shard_axes: A list of dimensions along which to shard `inputs`, or - `None`. `None` means "shard all inputs along dimension 0". If not `None`, - there must be one dimension per input. - outputs_from_all_shards: Boolean or list of boolean. For each output, if - `True`, outputs from all shards are concatenated along the corresponding - `output_shard_axes` entry. Otherwise, each output is taken - from an arbitrary shard. If the argument is a boolean, the argument's - value is used for each output. - output_shard_axes: A list of dimensions along which to concatenate the - outputs of `computation`, or `None`. `None` means "concatenate all outputs - along dimension 0". If not `None`, there must be one dimension per output. - Ignored if `outputs_from_all_shards` is False. - infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs - of `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A tuple of (compile op, [output tensors]). - Raises: - ValueError: If num_shards <= 0 - ValueError: If len(input_shard_axes) != len(inputs) - ValueError: If len(output_shard_axes) != len(outputs from `computation`) - """ - # TODO(phawkins): consider adding support for broadcasting Tensors passed as - # inputs. - - if num_shards <= 0: - raise ValueError("num_shards must be a positive integer.") - - inputs = [] if inputs is None else inputs - if not isinstance(inputs, list): - raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.") - - # Converts inputs to Tensors. - inputs = [ops.convert_to_tensor(x) for x in inputs] - - if input_shard_axes is None: - input_shard_axes = [0] * len(inputs) - if len(inputs) != len(input_shard_axes): - raise ValueError("Length of input_shard_axes must be equal to the number " - "of inputs.") - - if inputs: - # Splits the `inputs` along the corresponding `input_shard_axes`, giving - # lists with layout [input][shard] - split_inputs = [ - array_ops.split(x, num_shards, axis=axis) - for (axis, x) in zip(input_shard_axes, inputs)] - - # Transposes the input lists to have layout [shard][input] - transposed_inputs = [list(i) for i in zip(*split_inputs)] - else: - transposed_inputs = [[]] * num_shards - - compile_op, outputs = split_compile_and_replicate( - computation, - transposed_inputs, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - - # There must be at least one shard since num_shards > 0. - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - if isinstance(outputs[0], ops.Operation): - # pylint: enable=indexing-exception - # There were no outputs from the computation and replicate returned a list - # of NoOps with control dependencies on the computation. Return the first - # one so it can be used as a control dependency or fetch node. - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - return compile_op, [outputs[0]] - # pylint: enable=indexing-exception - - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - num_outputs = len(outputs[0]) - # pylint: enable=indexing-exception - - if output_shard_axes is None: - output_shard_axes = [0] * num_outputs - if num_outputs != len(output_shard_axes): - raise ValueError("Length of output_shard_axes must be equal to the number " - "of outputs.") - - if isinstance(outputs_from_all_shards, bool): - outputs_from_all_shards = [outputs_from_all_shards] * num_outputs - - if num_outputs != len(outputs_from_all_shards): - raise ValueError("Length of outputs_from_all_shards must be equal to the " - "number of outputs.") - - results = [] - for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, - zip(*outputs)): - if all_shards: - # Concatenate all of the outputs together (use stack for scalars). - shape = x[0].shape - is_scalar = shape is not None and (shape.ndims == 0) - results.append((array_ops.stack(list(x)) if is_scalar - else array_ops.concat(list(x), axis=axis))) - else: - # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. - results.append(x[0]) - - return compile_op, results - - -def shard(computation, - inputs=None, - num_shards=1, - input_shard_axes=None, - outputs_from_all_shards=True, - output_shard_axes=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` for parallel execution. - - `inputs` must be a list of Tensors or None (equivalent to an empty list), each - of which has a corresponding split axis (from `input_shard_axes`). Each input - is split into `num_shards` pieces along the corresponding axis, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - TODO(phawkins): consider adding support for broadcasting Tensors passed - as inputs. - - If `outputs_from_all_shards` is true, the outputs from all shards of - `computation` are concatenated back together along their `output_shards_axes`. - Otherwise, each output is taken from an arbitrary shard. - - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). Each - input tensor has a corresponding shard axes, given by `input_shard_axes`, - which must have size divisible by `num_shards`. - num_shards: The number of shards. - input_shard_axes: A list of dimensions along which to shard `inputs`, or - `None`. `None` means "shard all inputs along dimension 0". If not `None`, - there must be one dimension per input. - outputs_from_all_shards: Boolean or list of boolean. For each output, if - `True`, outputs from all shards are concatenated along the corresponding - `output_shard_axes` entry. Otherwise, each output is taken - from an arbitrary shard. If the argument is a boolean, the argument's - value is used for each output. - output_shard_axes: A list of dimensions along which to concatenate the - outputs of `computation`, or `None`. `None` means "concatenate all outputs - along dimension 0". If not `None`, there must be one dimension per output. - Ignored if `outputs_from_all_shards` is False. - infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs - of `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A list of output tensors. - Raises: - ValueError: If num_shards <= 0 - ValueError: If len(input_shard_axes) != len(inputs) - ValueError: If len(output_shard_axes) != len(outputs from `computation`) - """ - return split_compile_and_shard( - computation, - inputs=inputs, - num_shards=num_shards, - input_shard_axes=input_shard_axes, - outputs_from_all_shards=outputs_from_all_shards, - output_shard_axes=output_shard_axes, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name)[1] - - -def batch_parallel(computation, - inputs=None, - num_shards=1, - infeed_queue=None, - device_assignment=None, - name=None): - """Shards `computation` along the batch dimension for parallel execution. - - Convenience wrapper around shard(). - - `inputs` must be a list of Tensors or None (equivalent to an empty list). - Each input is split into `num_shards` pieces along the 0-th dimension, and - computation is applied to each shard in parallel. - - Tensors are broadcast to all shards if they are lexically captured by - `computation`. e.g., - - x = tf.constant(7) - def computation(): - return x + 3 - ... = shard(computation, ...) - - The outputs from all shards are concatenated back together along their 0-th - dimension. - - Inputs and outputs of the computation must be at least rank-1 Tensors. - - Args: - computation: A Python function that builds a computation to apply to each - shard of the input. - inputs: A list of input tensors or None (equivalent to an empty list). The - 0-th dimension of each Tensor must have size divisible by `num_shards`. - num_shards: The number of shards. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: If not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. Uses a default device assignment if `None`. The - `DeviceAssignment` may be omitted if each shard of the computation uses - only one core, and there is either only one shard, or the number of shards - is equal to the number of cores in the TPU system. - name: (Deprecated) Does nothing. - Returns: - A list of output tensors. - Raises: - ValueError: If `num_shards <= 0` - """ - return shard( - computation, - inputs, - num_shards=num_shards, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - - -def rewrite(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Rewrites `computation` for execution on a TPU system. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. - - `computation` may return a list of operations and tensors. Tensors must - come before operations in the returned list. The return value of - `rewrite` is a list of tensors corresponding to the tensors from the - output of `computation`. - - All `Operation`s constructed during `computation` will be executed when - evaluating any of the returned output tensors, not just the ones returned. - inputs: A list of input tensors or `None` (equivalent to an empty list). - Each input can be a nested structure containing values that are - convertible to tensors. Note that passing an N-dimension list of - compatible values will result in a N-dimention list of scalar tensors - rather than a single Rank-N tensors. If you need different behavior, - convert part of inputs to tensors with `tf.convert_to_tensor`. - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: if not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. May be omitted for a single-core computation, in which - case the core attached to task 0, TPU device 0 is used. - name: (Deprecated) Does nothing. - Returns: - Same data structure as if computation(*inputs) is called directly with some - exceptions for correctness. Exceptions include: - 1) None output: a NoOp would be returned which control-depends on - computation. - 2) Single value output: A tuple containing the value would be returned. - 3) Operation-only outputs: a NoOp would be returned which - control-depends on computation. - TODO(b/121383831): Investigate into removing these special cases. - """ - # TODO(b/36647078) remove disable when pylint bug is fixed. - # pylint: disable=indexing-exception - return replicate( - computation, - None if inputs is None else [inputs], - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name)[0] - # pylint: enable=indexing-exception - - # Operations that indicate some error in the user's inference graph. -_BLACKLISTED_INFERENCE_OPS = set([ - "ReadVariableOp", - "AssignVariableOp", - "AssignAddVariableOp", - "AssignSubVariableOp", - "VarHandleOp", - "Variable", - "VariableV2", -]) - - -def under_tpu_inference_context(): - """Check if it is currently under `tpu.rewrite_for_inference()`.""" - graph = ops.get_default_graph() - - context = graph._get_control_flow_context() # pylint: disable=protected-access - while context: - if isinstance(context, _TPUInferenceContext): - return True - context = context.outer_context - - return False - - -class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): - """A `ControlFlowContext` for nodes inside a TPU inference computation. - - The primary role of `TPUReplicateContext` is to sanity check operators inside - a tpu.rewrite_for_inference() computation. - """ - - def __init__(self, name): - super(_TPUInferenceContext, self).__init__() - self._name = name - - def AddOp(self, op): - self._AddOpInternal(op) - - def _AddOpInternal(self, op): - # pylint: disable=protected-access - if op.type in _BLACKLISTED_INFERENCE_OPS: - raise NotImplementedError( - "Operation of type %s (%s) is not supported on the TPU for inference." - " Execution will fail if this op is used in the graph. Make sure your" - " variables are using variable_scope." % (op.type, op.name)) - if self._outer_context: - self._outer_context.AddInnerOp(op) - - def AddValue(self, val): - result = val - if self._outer_context: - result = self._outer_context.AddValue(val) - return result - - def AddInnerOp(self, op): - self._AddOpInternal(op) - - @property - def grad_state(self): - return None - - -@experimental -def validate_inference_rewrite_for_variables(graph): - """Validates whether rewrite_for_inference() 'worked' for variables. - - The rewrite_for_inference() method is supposed to append GuaranteeConstOps - after ReadVariableOps, but this mechanism works only if you are using - tf.get_variable() to create and access variables in your tpu computation. - This validation method can be called immediately after calling - tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added - to the graph. - - Typical usages: - tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) - - tpu.validate_inference_rewrite_for_variables(sess.graph) - - Args: - graph: The graph which needs to be validated. - Raises: - RuntimeError: if validation failed. - """ - if not any(x.type == "GuaranteeConst" for x in graph.get_operations()): - raise RuntimeError( - "No GuaranteeConst ops found in the graph after running " - "tpu.rewrite_for_inference(...). Please check that you are using " - "tf.get_variable() to create and access variables in your tpu " - "computation.") - - -@experimental -def rewrite_for_inference(computation, - inputs=None, - infeed_queue=None, - device_assignment=None, - name=None): - """Rewrites `computation` for inference on a TPU system. - - Other than 'rewriting' the computation to run on a TPU, if using variables - in your computation, it moves the ReadVariableOps outside the TPU - computation, and adds GuaranteeConst ops just after the ReadVariableOps. - This mechanism works only if you are using tf.get_variable() to create and - access variables in your tpu computation. You can validate whether this - worked, by calling validate_inference_rewrite_for_variables() method - immediately after this method to check whether GuaranteeConstOps where - added to the graph. - - Args: - computation: A Python function that builds a computation to apply to the - input. If the function takes n inputs, 'inputs' should be a list of n - tensors. If the function returns m outputs, rewrite will return a list of - m tensors. - inputs: A list of input tensors or `None` (equivalent to an empty list). - infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple - of arguments as inputs to `computation`. - device_assignment: if not `None`, a `DeviceAssignment` describing the - mapping between logical cores in the computation with physical cores in - the TPU topology. May be omitted for a single-core computation, in which - case the core attached to task 0, TPU device 0 is used. - name: The name of the operator. - Returns: - A list of output tensors. - """ - - def guarantee_const_getter(getter, name, *args, **kwargs): - with ops.control_dependencies(None): - return array_ops.guarantee_const( - getter(name, *args, **kwargs), name=name + "/GuaranteeConst") - - def wrapped_computation(*args, **kwargs): - """Execute computation under `_TPUInferenceContext`.""" - context = _TPUInferenceContext( - name=ops.get_default_graph().unique_name("rewrite_for_inference")) - try: - context.Enter() - - vscope = variable_scope.get_variable_scope() - prev_custom_getter = vscope.custom_getter - prev_caching_device = vscope.caching_device - vscope.set_custom_getter(guarantee_const_getter) - vscope.set_caching_device(lambda op: op.device) - - result = computation(*args, **kwargs) - - vscope.set_custom_getter(prev_custom_getter) - vscope.set_caching_device(prev_caching_device) - finally: - context.Exit() - return result - - # pylint: disable=undefined-variable - return rewrite( - wrapped_computation, - inputs=inputs, - infeed_queue=infeed_queue, - device_assignment=device_assignment, - name=name) - # pylint: enable=undefined-variable +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu import * +# used by tests +from tensorflow.python.tpu.tpu import _TPU_REPLICATE_ATTR +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 9f8d14706845baa1ed45c84b2c15d372915a0eb4..c36aaa38c0e4823bfc438773e4aa5b5109794da4 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -1,275 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""A RunConfig subclass with TPU support.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import json -import os - -from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.estimator import run_config as run_config_lib -from tensorflow.python.platform import tf_logging as logging - -# pylint: disable=protected-access -_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV -_SERVICE_KEY = run_config_lib._SERVICE_KEY -_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name' -# pylint: enable=protected-access - - -class InputPipelineConfig(object): - r"""Please see the definition of these values in TPUConfig.""" - PER_SHARD_V1 = 1 - PER_HOST_V1 = 2 - PER_HOST_V2 = 3 - BROADCAST = 4 - - -class TPUConfig( - collections.namedtuple('TPUConfig', [ - 'iterations_per_loop', - 'num_shards', - 'num_cores_per_replica', - 'per_host_input_for_training', - 'tpu_job_name', - 'initial_infeed_sleep_secs', - 'input_partition_dims', - ])): - r"""TPU related configuration required by `TPUEstimator`. - - Args: - iterations_per_loop: This is the number of train steps running in TPU - system before returning to CPU host for each `Session.run`. This means - global step is increased `iterations_per_loop` times in one `Session.run`. - It is recommended to be set as number of global steps for next checkpoint. - num_shards: (Deprecated, ignored by TPUEstimator). - The number of model replicas in the system. For non-model-parallelism - case, this number equals the total number of TPU cores. For - model-parallelism, the total number of TPU cores equals - num_cores_per_replica * num_shards. - num_cores_per_replica: Defaults to `None`, which disables model parallelism. - An integer which describes the number of TPU cores per model replica. This - is required by model-parallelism which enables partitioning - the model to multiple cores. Currently num_cores_per_replica must be - 1, 2, 4, or 8. - per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`, - `input_fn` is invoked once on each host. With the per-core input pipeline - configuration, it is invoked once for each core. - With a global batch size `train_batch_size` in `TPUEstimator` constructor, - the batch size for each shard is `train_batch_size` // #hosts in the - `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is - `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only - invoked once on host 0 and the tensors are broadcasted to all other - replicas. The batch size equals to train_batch_size`. With the per-core - input pipeline configuration, the shard batch size is also - `train_batch_size` // #cores. - Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN. - tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred - within TPUEstimator, however when using ClusterSpec propagation in more - esoteric cluster configurations, you may need to specify the job name as a - string. - initial_infeed_sleep_secs: The number of seconds the infeed thread should - wait before enqueueing the first batch. This helps avoid timeouts for - models that require a long compilation time. - input_partition_dims: A nested list to describe the partition dims - for all the tensors from input_fn(). The structure of - input_partition_dims must match the structure of `features` and - `labels` from input_fn(). The total number of partitions must match - `num_cores_per_replica`. For example, if input_fn() returns two tensors: - images with shape [N, H, W, C] and labels [N]. - input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4 - pieces and feed into 4 TPU cores. labels tensor are directly broadcasted - to all the TPU cores since the partition dims is `None`. - Current limitations: This feature is only supported with the PER_HOST_V2 - input mode. - - Raises: - ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16. - """ - - def __new__(cls, - iterations_per_loop=2, - num_shards=None, - num_cores_per_replica=None, - per_host_input_for_training=True, - tpu_job_name=None, - initial_infeed_sleep_secs=None, - input_partition_dims=None): - - # Check iterations_per_loop. - util_lib.check_positive_integer(iterations_per_loop, - 'TPUConfig iterations_per_loop') - - # Check num_shards. - if num_shards is not None: - util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') - - if input_partition_dims is not None: - if len(input_partition_dims) != 1 and len(input_partition_dims) != 2: - raise ValueError( - 'input_partition_dims must be a list/tuple with one or two' - ' elements.') - - if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2: - raise ValueError( - 'input_partition_dims is only supported in PER_HOST_V2 mode.') - - if num_cores_per_replica is None: - raise ValueError( - 'input_partition_dims requires setting num_cores_per_replica.') - - # Check num_cores_per_replica - if num_cores_per_replica is not None: - if num_cores_per_replica not in [1, 2, 4, 8, 16]: - raise ValueError( - 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format( - str(num_cores_per_replica))) - - # per_host_input_for_training may be True, False, or integer in [1..3]. - # Map legacy values (True, False) to numeric values. - if per_host_input_for_training is False: - per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1 - elif per_host_input_for_training is True: - per_host_input_for_training = InputPipelineConfig.PER_HOST_V1 - - # Check initial_infeed_sleep_secs. - if initial_infeed_sleep_secs: - util_lib.check_positive_integer(initial_infeed_sleep_secs, - 'TPUConfig initial_infeed_sleep_secs') - - tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config() - - return super(TPUConfig, cls).__new__( - cls, - iterations_per_loop=iterations_per_loop, - num_shards=num_shards, - num_cores_per_replica=num_cores_per_replica, - per_host_input_for_training=per_host_input_for_training, - tpu_job_name=tpu_job_name, - initial_infeed_sleep_secs=initial_infeed_sleep_secs, - input_partition_dims=input_partition_dims) - - -class RunConfig(run_config_lib.RunConfig): - """RunConfig with TPU support.""" - - def __init__(self, - tpu_config=None, - evaluation_master=None, - master=None, - cluster=None, - **kwargs): - """Constructs a RunConfig. - - Args: - tpu_config: the TPUConfig that specifies TPU-specific configuration. - evaluation_master: a string. The address of the master to use for eval. - Defaults to master if not set. - master: a string. The address of the master to use for training. - cluster: a ClusterResolver - **kwargs: keyword config parameters. - - Raises: - ValueError: if cluster is not None and the provided session_config has a - cluster_def already. - """ - super(RunConfig, self).__init__(**kwargs) - self._tpu_config = tpu_config or TPUConfig() - self._cluster = cluster - - # If user sets master and/or evaluation_master explicitly, including empty - # string '', take it. Otherwise, take the values set by parent class. - if master is not None: - if cluster is not None: - raise ValueError('Both master and cluster are set.') - self._master = master - else: - if cluster: - self._master = cluster.master() - - if evaluation_master is not None: - self._evaluation_master = evaluation_master - elif (not self._evaluation_master and - self.task_type != run_config_lib.TaskType.EVALUATOR): - # If the task type is EVALUATOR, it means some cluster manager sets the - # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG. - # - # Otherwise, it means user executes the code without external cluster - # manager. For that, we optimize the user experience by setting - # evaluation_master to master, unless user overwrites it. - self._evaluation_master = self._master - - # Set the ClusterSpec to use - if cluster: - self._cluster_spec = cluster.cluster_spec() - - # Merge the cluster_def into the ConfigProto. - if self._session_config is None: # pylint: disable=access-member-before-definition - self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) - if self._session_config.HasField('cluster_def'): - raise ValueError( - 'You cannot provide a ClusterResolver and ' - 'session_config.cluster_def.') - if self._cluster_spec: - self._session_config.cluster_def.CopyFrom( - self._cluster_spec.as_cluster_def()) - - def _maybe_overwrite_session_config_for_distributed_training(self): - # Overrides the parent class session_config overwrite for between-graph. TPU - # runs with in-graph, which should not have device filter. Doing nothing - # ("pass") basically disables it. - pass - - @property - def evaluation_master(self): - return self._evaluation_master - - @property - def master(self): - return self._master - - @property - def tpu_config(self): - return self._tpu_config - - @property - def cluster(self): - return self._cluster - - def replace(self, **kwargs): - if 'tpu_config' not in kwargs: - return super(RunConfig, self).replace(**kwargs) - - tpu_config = kwargs.pop('tpu_config') - new_instance = super(RunConfig, self).replace(**kwargs) - new_instance._tpu_config = tpu_config # pylint: disable=protected-access - return new_instance - - -def _get_tpu_job_name_from_tf_config(): - """Extracts the TPU job name from TF_CONFIG env variable.""" - # TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster - # spec propagation. - tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) - tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME) - if tpu_job_name: - logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name) - return tpu_job_name +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_config import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index ed1e0f0401a96c34e6ff9323685857b64e10bd14..b77b010cba6bf32c3b6d170bc522eebfb6a04f77 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -1,763 +1,23 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""TPU system metadata and associated tooling.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from contextlib import contextmanager -import copy - -from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding -from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.platform import tf_logging as logging - - -_DEFAULT_JOB_NAME = 'tpu_worker' -_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' -_LOCAL_MASTERS = ('', 'local') -_NUM_CORES_TO_COMPUTATION_SHAPE = { - 1: [1, 1, 1], - 2: [1, 1, 2], - 4: [1, 2, 2], - 8: [2, 2, 2], - 16: [4, 2, 2], -} - - -class TPUContext(object): - """A context that holds the current configuration of the TPU computation.""" - - def __init__(self, - internal_ctx, - input_device=None, - invocation_index=None, - call_from_input_fn=True): - self._internal_ctx = internal_ctx - self._input_device = input_device - self._invocation_index = invocation_index - self._call_from_input_fn = call_from_input_fn - - def current_input_fn_deployment(self): - """The configuration of the current input_fn invocation. - - The configuration depends on `TPUConfig.per_host_input_for_training`. See - `TPUConfig` for details. - - Only set in params dict of input_fn - - Returns: - A tuple of - 1. Device spec string: String, is the current CPU host where the - input_fn is invoked. - 2. Current invocation index: Int, 0-based index of the input_fn - invocation. See next item for details. - 3. Total invocation count: Int, the total number of times to invoke the - input_fn on all CPU hosts. Each invocation will be passed with a new - `TPUContext` instance with current invocation index set properly. - 4. Total number of replicas consumed by current_invocation: Int, the - number of replicas fed by the data returned by current input_fn. For - example, for per_core input pipeline deployment - and non-model-parallelism, total invocation count is equal to - the number of cores in the system and num replicas consumed by - current invocation is 1. For per-host v2 input pipeline deployment, - total invocation count is equal to the number of hosts in the system - and num replicas consumed by current invocation is equal to number of - cores per host. - - Raises: - RuntimeError: If this method must not be called from input_fn. - """ - if not self._call_from_input_fn: - raise RuntimeError('This TPUContext instance must not be called from' - ' model_fn.') - - if self._internal_ctx.is_input_sharded_per_core(): - total_invocation_count = (self._internal_ctx.num_hosts - * self._internal_ctx.num_of_replicas_per_host) - replicas_consumed = 1 - elif self._internal_ctx.is_input_broadcast_with_iterators(): - total_invocation_count = 1 - replicas_consumed = self._internal_ctx.num_replicas - else: - total_invocation_count = self._internal_ctx.num_hosts - replicas_consumed = self._internal_ctx.num_of_replicas_per_host - return (self._input_device, self._invocation_index, - total_invocation_count, replicas_consumed) - - @property - def num_replicas(self): - """The total number of replicas. - - For non-model-parallelism, num_replicas should be the total num of TPU - cores in the system. - - Returns: - The number of replicas. - """ - return self._internal_ctx.num_replicas - - @property - def num_hosts(self): - """The number of hosts for the TPU system.""" - return self._internal_ctx.num_hosts - - @property - def current_host(self): - """The current host index for the TPU system.""" - return self._invocation_index - - @property - def num_of_replicas_per_host(self): - """The number of replicas for each host.""" - if self._internal_ctx.model_parallelism_enabled: - raise ValueError( - 'num_of_replicas_per_host is not supported for model_parallelism') - return self._internal_ctx.num_of_replicas_per_host - - @property - def device_assignment(self): - """Returns device_assignment object.""" - if self._call_from_input_fn: - raise RuntimeError('This TPUContext instance must not be called from' - ' input_fn.') - return self._internal_ctx.device_assignment - - def device_for_replica(self, replica_id): - """Returns the tuple of (CPU device and device ordinal) for replica. - - This should be used for full replicate for non-model-parallelism. - - Args: - replica_id: Int, the replica index. - - Returns: - A tuple of device spec for CPU device and int device ordinal. - """ - # Note that: For the non-model parallelism, the mapping could be - # a random permutation. The order should not matter in most cases - # as far as model is replicated to all cores in the system. - return self._internal_ctx.device_for_replica(replica_id) - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function. - - The place function takes host_id as the input and returns the TF device - for the correspoding host. - """ - - def _placement_function(host_id): - """Return the host device given host_id.""" - return self._internal_ctx.tpu_host_placement_function(host_id=host_id) - - return _placement_function - - -class _InternalTPUContext(object): - """A context holds immutable states of TPU computation. - - This immutable object holds TPUEstimator config, train/eval batch size, and - `TPUEstimator.use_tpu`, which is expected to be passed around. It also - provides utility functions, based on the current state, to determine other - information commonly required by TPU computation, such as TPU device names, - TPU hosts, shard batch size, etc. - - if eval_on_tpu is False, then execution of eval on TPU is disabled. - if eval_on_tpu is True, but use_tpu is False, a warning is issued, - and TPU execution is disabled for all modes. - - N.B. As `mode` is not immutable state in Estimator, but essential to - distinguish between TPU training and evaluation, a common usage for - _InternalTPUContext with `mode` is as follows: - ``` - with _ctx.with_mode(mode) as ctx: - if ctx.is_running_on_cpu(): - ... - ``` - """ - - def __init__(self, - config, - train_batch_size, - eval_batch_size, - predict_batch_size, - use_tpu, - eval_on_tpu=True, - embedding_config_spec=None): - self._config = config - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size - self._predict_batch_size = predict_batch_size - self._use_tpu = use_tpu - logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu) - if not use_tpu and eval_on_tpu: - logging.warning('eval_on_tpu ignored because use_tpu is False.') - - self._eval_on_tpu = eval_on_tpu - self._model_parallelism_enabled = ( - use_tpu and config.tpu_config.num_cores_per_replica) - self._mode = None - num_cores_per_replica = config.tpu_config.num_cores_per_replica - if self._model_parallelism_enabled: - self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ - num_cores_per_replica] - else: - self._computation_shape = None - self._lazy_tpu_system_metadata_dict = {} # key by master address - self._lazy_device_assignment_dict = {} # key by master address - self._lazy_validation_dict = {} # key by ModeKeys - self._embedding_config_spec = embedding_config_spec - self._lazy_embedding_config_dict = {} # key by master address - - def _assert_mode(self): - if self._mode is None: - raise RuntimeError( - '`mode` needs to be set via contextmanager `with_mode`.') - return self._mode - - @contextmanager - def with_mode(self, mode): - # NOTE(xiejw): Shallow copy is enough. It will share he lazy dictionaries, - # such as _lazy_tpu_system_metadata_dict between new copy and the original - # one. Note that all lazy states stored in properties _lazy_foo are sort of - # immutable as they should be same for the process lifetime. - new_ctx = copy.copy(self) - new_ctx._mode = mode # pylint: disable=protected-access - yield new_ctx - - @property - def mode(self): - return self._assert_mode() - - def _get_master_address(self): - mode = self._assert_mode() - config = self._config - master = ( - config.master - if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master) - return master - - def _get_tpu_system_metadata(self): - """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() - tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) - if tpu_system_metadata is not None: - return tpu_system_metadata - - cluster_def = None - if (self._config.session_config and - self._config.session_config.cluster_def.job): - cluster_def = self._config.session_config.cluster_def - - # pylint: disable=protected-access - tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata( - master, - cluster_def=cluster_def, - query_topology=self.model_parallelism_enabled)) - - self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata - return tpu_system_metadata - - def _get_device_assignment(self): - """Gets the (maybe cached) TPU device assignment.""" - master = self._get_master_address() - device_assignment = self._lazy_device_assignment_dict.get(master) - if device_assignment is not None: - return device_assignment - - tpu_system_metadata = self._get_tpu_system_metadata() - - device_assignment = tpu_device_assignment.device_assignment( - tpu_system_metadata.topology, - computation_shape=self._computation_shape, - num_replicas=self.num_replicas) - - logging.info('num_cores_per_replica: %s', - str(self._config.tpu_config.num_cores_per_replica)) - logging.info('computation_shape: %s', str(self._computation_shape)) - logging.info('num_replicas: %d', self.num_replicas) - logging.info('device_assignment.topology.device_coordinates: %s', - str(device_assignment.topology.device_coordinates)) - logging.info('device_assignment.core_assignment: %s', - str(device_assignment.core_assignment)) - - self._lazy_device_assignment_dict[master] = device_assignment - return device_assignment - - @property - def embedding_config(self): - """Returns the embedding config based on current mode.""" - master = self._get_master_address() - if master in self._lazy_embedding_config_dict: - embedding_config = self._lazy_embedding_config_dict[master] - else: - embedding_config = None - if self._use_tpu and self._embedding_config_spec: - embedding_config = _tpu_estimator_embedding.EmbeddingConfig( - self._embedding_config_spec, self._train_batch_size, - self._eval_batch_size, self.num_hosts, self.num_cores, master) - if not embedding_config.has_embedding_tables(): - embedding_config = None - self._lazy_embedding_config_dict[master] = embedding_config - - if embedding_config is not None: - mode = self._assert_mode() - # Dynamically attach tpu_embedding based on mode. With - # this, we could keep embedding_config immutable but call site always - # accesses the unified API '.tpu_embedding'. - embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode) - return embedding_config - - @property - def model_parallelism_enabled(self): - return self._model_parallelism_enabled - - @property - def input_partition_dims(self): - return self._config.tpu_config.input_partition_dims - - @property - def device_assignment(self): - return (self._get_device_assignment() - if self._model_parallelism_enabled else None) - - @property - def num_of_cores_per_host(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_of_cores_per_host - - @property - def num_cores(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_cores - - @property - def num_of_replicas_per_host(self): - """Return the number of replicas per host.""" - if self.model_parallelism_enabled: - return self.num_replicas // self.num_hosts - else: - return self.num_of_cores_per_host - - @property - def num_replicas(self): - num_cores_in_system = self.num_cores - - if self.model_parallelism_enabled: - num_cores_per_replica = self._config.tpu_config.num_cores_per_replica - if num_cores_per_replica > num_cores_in_system: - raise ValueError( - 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica, is larger than the total num of ' - 'TPU cores in the system. num_cores_per_replica: {}, num cores ' - 'in the system: {}'.format(num_cores_per_replica, - num_cores_in_system)) - - if num_cores_in_system % num_cores_per_replica != 0: - raise RuntimeError( - 'The num of cores in the system ({}) is not divisible by the num ' - 'of cores ({}) required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica. This should never happen!'.format( - num_cores_in_system, num_cores_per_replica)) - - return num_cores_in_system // num_cores_per_replica - else: - return num_cores_in_system - - @property - def num_hosts(self): - metadata = self._get_tpu_system_metadata() - return metadata.num_hosts - - @property - def config(self): - return self._config - - def is_input_sharded_per_core(self): - """Return true if input_fn is invoked per-core (other than per-host).""" - mode = self._assert_mode() - return (mode == model_fn_lib.ModeKeys.TRAIN and - (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1)) - - def is_input_per_host_with_iterators(self): - """Return true if input_fn should be run in the per-host v2 config.""" - return (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_HOST_V2) - - def is_input_broadcast_with_iterators(self): - """Return true if input_fn should be run in the full_replicae config.""" - return (self._config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.BROADCAST) - - def is_running_on_cpu(self, is_export_mode=False): - """Determines whether the input_fn and model_fn should be invoked on CPU. - - This API also validates user provided configuration, such as batch size, - according the lazy initialized TPU system metadata. - - Args: - is_export_mode: Indicates whether the current mode is for exporting the - model, when mode == PREDICT. Only with this bool, we could - tell whether user is calling the Estimator.predict or - Estimator.export_savedmodel, which are running on TPU and CPU - respectively. Parent class Estimator does not distinguish these two. - - Returns: - bool, whether current input_fn or model_fn should be running on CPU. - - Raises: - ValueError: any configuration is invalid. - """ - - is_running_on_cpu = self._is_running_on_cpu(is_export_mode) - if not is_running_on_cpu: - self._validate_tpu_configuration() - return is_running_on_cpu - - def _is_running_on_cpu(self, is_export_mode): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - mode = self._assert_mode() - - if not self._use_tpu: - return True - - if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu: - logging.info('_is_running_on_cpu: eval_on_tpu disabled') - return True - - if is_export_mode: - return True - - return False - - @property - def global_batch_size(self): - mode = self._assert_mode() - if mode == model_fn_lib.ModeKeys.TRAIN: - return self._train_batch_size - elif mode == model_fn_lib.ModeKeys.EVAL: - return self._eval_batch_size - elif mode == model_fn_lib.ModeKeys.PREDICT: - return self._predict_batch_size - else: - return None - - @property - def batch_size_for_input_fn(self): - """Returns the shard batch size for `input_fn`.""" - global_batch_size = self.global_batch_size - - if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): - return global_batch_size - - # On TPU - if self.is_input_sharded_per_core() or ( - self.is_input_per_host_with_iterators()): - return global_batch_size // self.num_replicas - else: - return global_batch_size // self.num_hosts - - @property - def batch_size_for_model_fn(self): - """Returns the shard batch size for `model_fn`.""" - global_batch_size = self.global_batch_size - - if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): - return global_batch_size - - # On TPU. always sharded per shard. - return global_batch_size // self.num_replicas - - @property - def master_job(self): - """Returns the job name to use to place TPU computations on. - - Returns: - A string containing the job name, or None if no job should be specified. - - Raises: - ValueError: If the user needs to specify a tpu_job_name, because we are - unable to infer the job name automatically, or if the user-specified job - names are inappropriate. - """ - run_config = self._config - # If the user specifies the tpu_job_name, use that. - if run_config.tpu_config.tpu_job_name: - return run_config.tpu_config.tpu_job_name - - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - mode = self._assert_mode() - master = ( - run_config.evaluation_master - if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) - if master in _LOCAL_MASTERS: - return None - - if (not run_config.session_config or - not run_config.session_config.cluster_def.job): - return _DEFAULT_JOB_NAME - cluster_def = run_config.session_config.cluster_def - job_names = set([job.name for job in cluster_def.job]) - if _DEFAULT_JOB_NAME in job_names: - # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. - raise ValueError('Currently, tpu_worker is not an allowed job name.') - if len(job_names) == 1: - return cluster_def.job[0].name - if len(job_names) == 2: - if _DEFAULT_COORDINATOR_JOB_NAME in job_names: - job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) - return job_names.pop() - # TODO(b/67716447): Include more sophisticated heuristics. - raise ValueError( - 'Could not infer TPU job name. Please specify a tpu_job_name as part ' - 'of your TPUConfig.') - - @property - def tpu_host_placement_function(self): - """Returns the TPU host place function.""" - - master = self.master_job - - def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name - """Return the host device given replica_id or host_id.""" - assert _sentinal is None - if replica_id is not None and host_id is not None: - raise RuntimeError( - 'replica_id and host_id can have only one non-None value.') - - if master is None: - return '/replica:0/task:0/device:CPU:0' - else: - if replica_id is not None: - if self.model_parallelism_enabled: - return self.device_assignment.host_device( - replica=replica_id, job=master) - else: - host_id = replica_id / self.num_of_cores_per_host - - return '/job:%s/task:%d/device:CPU:0' % (master, host_id) - - return _placement_function - - @property - def tpu_device_placement_function(self): - """Returns a TPU device placement Fn.""" - master = self.master_job - job_device = '' if master is None else ('/job:%s' % master) - - def _placement_function(i): - if self.model_parallelism_enabled: - return self.device_assignment.tpu_device(replica=i, job=master) - else: - num_of_cores_per_host = self.num_of_cores_per_host - host_id = i / num_of_cores_per_host - ordinal_id = i % num_of_cores_per_host - return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id) - - return _placement_function - - def tpu_ordinal_function(self, host_id): - """Returns the TPU ordinal fn.""" - - def _tpu_ordinal_function(shard_index_in_host): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - shard_index_in_host: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - if self.model_parallelism_enabled: - # We put both enqueue/dequeue ops at tpu.core(0) in each replica. - replica = self.device_assignment.lookup_replicas(host_id, - 0)[shard_index_in_host] - return self.device_assignment.tpu_ordinal(replica=replica) - else: - return shard_index_in_host % self.num_of_cores_per_host - - return _tpu_ordinal_function - - def _validate_tpu_configuration(self): - """Validates the configuration based on the TPU system metadata.""" - mode = self._assert_mode() - if self._lazy_validation_dict.get(mode): - return - - # All following information is obtained from TPU system metadata. - num_cores = self.num_cores - num_replicas = self.num_replicas - num_hosts = self.num_hosts - - if not num_cores: - tpu_system_metadata = self._get_tpu_system_metadata() - raise RuntimeError( - 'Cannot find any TPU cores in the system. Please double check ' - 'Tensorflow master address and TPU worker(s). Available devices ' - 'are {}.'.format(tpu_system_metadata.devices)) - - if self._config.tpu_config.num_shards: - user_provided_num_replicas = self._config.tpu_config.num_shards - if user_provided_num_replicas != num_replicas: - message = ( - 'TPUConfig.num_shards is not set correctly. According to TPU ' - 'system metadata for Tensorflow master ({}): num_replicas should ' - 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' - 'be the total num of TPU cores in the system. For ' - 'model-parallelism, the total number of TPU cores should be ' - 'num_cores_per_replica * num_replicas. Please set it ' - 'accordingly or leave it as `None`'.format( - self._get_master_address(), num_replicas, - user_provided_num_replicas)) - - raise ValueError(message) - - if self._config.tpu_config.num_cores_per_replica: - num_cores_per_replica = self._config.tpu_config.num_cores_per_replica - num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host - if num_cores_per_replica > num_cores_per_host: - raise ValueError( - 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.num_cores_per_replica, is larger than the ' - 'num_cores_per_host. num_cores_per_replica: {}, ' - 'num_cores_per_host: {}'.format(num_cores_per_replica, - num_cores_per_host)) - - if mode == model_fn_lib.ModeKeys.TRAIN: - if (self._train_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'train batch size {} must be divisible by number of replicas {}' - .format(self._train_batch_size, num_replicas)) - - elif mode == model_fn_lib.ModeKeys.EVAL: - if self._eval_batch_size is None: - raise ValueError( - 'eval_batch_size in TPUEstimator constructor cannot be `None`' - 'if .evaluate is running on TPU.') - if (self._eval_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'eval batch size {} must be divisible by number of replicas {}' - .format(self._eval_batch_size, num_replicas)) - if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): - raise ValueError( - 'TPUEstimator.evaluate should be running on single TPU' - ' instead of a Pod.') - else: - assert mode == model_fn_lib.ModeKeys.PREDICT - if self._predict_batch_size is None: - raise ValueError( - 'predict_batch_size in TPUEstimator constructor should not be ' - '`None` if .predict is running on TPU.') - if (self._predict_batch_size % num_replicas != 0 and - not self.is_input_broadcast_with_iterators()): - raise ValueError( - 'predict batch size {} must be divisible by number of replicas {}' - .format(self._predict_batch_size, num_replicas)) - if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): - raise ValueError( - 'TPUEstimator.predict should be running on single TPU worker. ' - 'got {}.'.format(num_hosts)) - - # Record the state "validated" into lazy dictionary. - self._lazy_validation_dict[mode] = True - - def device_for_replica(self, replica_id): - """Returns the tuple of (CPU device and device ordinal) for replica. - - This should be used for full replicate for non-model-parallelism. - - Args: - replica_id: Int, the replica index. - - Returns: - A tuple of device spec for CPU device and int device ordinal. - """ - master = self.master_job - - if self.model_parallelism_enabled: - return (self.device_assignment.host_device( - replica=replica_id, job=master), - self.device_assignment.tpu_ordinal(replica=replica_id)) - - job_device = '' if master is None else ('/job:%s' % master) - - num_of_replicas_per_host = self.num_of_replicas_per_host - host_id = replica_id / num_of_replicas_per_host - ordinal_id = replica_id % num_of_replicas_per_host - - host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) - return (host_device, ordinal_id) - - -class _OneCoreTPUContext(_InternalTPUContext): - """Special _InternalTPUContext for one core usage.""" - - def __init__(self, config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu): - - super(_OneCoreTPUContext, self).__init__( - config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) - - def _get_tpu_system_metadata(self): - """Gets the (maybe cached) TPU system metadata.""" - master = self._get_master_address() - tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) - if tpu_system_metadata is not None: - return tpu_system_metadata - - tpu_system_metadata = ( - tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access - num_cores=1, - num_hosts=1, - num_of_cores_per_host=1, - topology=None, - devices=[])) - - self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata - return tpu_system_metadata - - -def _get_tpu_context(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu, - embedding_config_spec): - """Returns an instance of `_InternalTPUContext`.""" - - if (config.tpu_config.num_shards == 1 and - config.tpu_config.num_cores_per_replica is None): - if embedding_config_spec is not None: - raise ValueError('Setting TPUConfig.num_shards==1 is unsupported ' - 'when embedding_config_spec is not None.') - logging.warning( - 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' - 'Please fix as soon as possible (leaving num_shards as None.)') - return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu) - - return _InternalTPUContext(config, train_batch_size, eval_batch_size, - predict_batch_size, use_tpu, eval_on_tpu, - embedding_config_spec) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_context import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index 0e4597bd6fae500c93f74fcb1b16a39739d2310c..cb38a8f1a6bee3c2adfbefc203c1d143303c3368 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -1,10 +1,10 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,1076 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TPU embedding APIs.""" +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy -import math -import re -import six - -from tensorflow.contrib.framework.python.framework import experimental -from tensorflow.contrib.tpu.ops import gen_tpu_ops -from tensorflow.contrib.tpu.proto import tpu_embedding_configuration_pb2 as elc -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import partitioned_variables -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables - -TRAINING = elc.TPUEmbeddingConfiguration.TRAINING -INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE - - -class TableConfig( - collections.namedtuple( - 'TableConfig', - ['vocabulary_size', 'dimension', 'initializer', 'combiner'])): - """Embedding table configuration.""" - - @experimental - def __new__(cls, - vocabulary_size, - dimension, - initializer=None, - combiner='mean'): - """Embedding table configuration. - - Args: - vocabulary_size: Number of vocabulary (/rows) in the table. - dimension: The embedding dimension. - initializer: A variable initializer function to be used in embedding - variable initialization. If not specified, defaults to - `tf.truncated_normal_initializer` with mean `0.0` and standard deviation - `1/sqrt(dimension)`. - combiner: A string specifying how to reduce if there are multiple entries - in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with - 'mean' the default. 'sqrtn' often achieves good accuracy, in particular - with bag-of-words columns. For more information, see - `tf.nn.embedding_lookup_sparse`. - - Returns: - `TableConfig`. - - Raises: - ValueError: if `vocabulary_size` is not positive integer. - ValueError: if `dimension` is not positive integer. - ValueError: if `initializer` is specified and is not callable. - ValueError: if `combiner` is not supported. - """ - if not isinstance(vocabulary_size, int) or vocabulary_size < 1: - raise ValueError('Invalid vocabulary_size {}.'.format(vocabulary_size)) - - if not isinstance(dimension, int) or dimension < 1: - raise ValueError('Invalid dimension {}.'.format(dimension)) - - if (initializer is not None) and (not callable(initializer)): - raise ValueError('initializer must be callable if specified.') - if initializer is None: - initializer = init_ops.truncated_normal_initializer( - mean=0.0, stddev=1 / math.sqrt(dimension)) - - if combiner not in ('mean', 'sum', 'sqrtn'): - raise ValueError('Invalid combiner {}'.format(combiner)) - - return super(TableConfig, cls).__new__(cls, vocabulary_size, dimension, - initializer, combiner) - - -AdamSlotVariableNames = collections.namedtuple( - 'AdamSlotVariableNames', ['m', 'v']) - -AdagradSlotVariableName = collections.namedtuple( - 'AdagradSlotVariableName', ['accumulator']) - -AdamSlotVariables = collections.namedtuple( - 'AdamSlotVariables', ['m', 'v']) - -AdagradSlotVariable = collections.namedtuple( - 'AdagradSlotVariable', ['accumulator']) - -VariablesAndOps = collections.namedtuple( - 'VariablesAndOps', - ['embedding_variables_by_table', 'slot_variables_by_table', - 'load_ops', 'retrieve_ops'] -) - - -# TODO(shizhiw): Factor `use_gradient_accumulation` and -# `pipeline_execution_with_tensor_core` out of `_OptimizationParameters`. -class _OptimizationParameters(object): - """Parameters common to all optimizations.""" - - def __init__(self, learning_rate, use_gradient_accumulation, - pipeline_execution_with_tensor_core): - self.learning_rate = learning_rate - self.use_gradient_accumulation = use_gradient_accumulation - self.pipeline_execution_with_tensor_core = ( - pipeline_execution_with_tensor_core) - - -class AdagradParameters(_OptimizationParameters): - """Optimization parameters for Adagrad.""" - - def __init__(self, learning_rate, initial_accumulator, - use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - """Optimization parameters for Adagrad. - - Args: - learning_rate: used for updating embedding table. - initial_accumulator: initial accumulator for Adagrad. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - super(AdagradParameters, self).__init__(learning_rate, - use_gradient_accumulation, - pipeline_execution_with_tensor_core) - self.initial_accumulator = initial_accumulator - - -class AdamParameters(_OptimizationParameters): - """Optimization parameters for Adam.""" - - def __init__(self, learning_rate, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - lazy_adam=True, - sum_inside_sqrt=True, - use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - """Optimization parameters for Adam. - - Args: - learning_rate: a floating point value. The learning rate. - beta1: A float value. - The exponential decay rate for the 1st moment estimates. - beta2: A float value. - The exponential decay rate for the 2nd moment estimates. - epsilon: A small constant for numerical stability. - lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. - Please see `optimization_parameters.proto` for details. - sum_inside_sqrt: This improves training speed. Please see - `optimization_parameters.proto` for details. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - super(AdamParameters, self).__init__(learning_rate, - use_gradient_accumulation, - pipeline_execution_with_tensor_core) - self.beta1 = beta1 - self.beta2 = beta2 - self.epsilon = epsilon - self.lazy_adam = lazy_adam - self.sum_inside_sqrt = sum_inside_sqrt - - -class StochasticGradientDescentParameters(_OptimizationParameters): - """Optimization parameters for stochastic gradient descent. - - Args: - learning_rate: a floating point value. The learning rate. - use_gradient_accumulation: setting this to `True` makes embedding - gradients calculation more accurate but slower. Please see - `optimization_parameters.proto` for details. - pipeline_execution_with_tensor_core: setting this to `True` makes training - faster, but trained model will be different if step N and step N+1 - involve the same set of embedding ID. Please see - `tpu_embedding_configuration.proto` for details. - """ - - def __init__(self, learning_rate, use_gradient_accumulation=False, - pipeline_execution_with_tensor_core=True): - super(StochasticGradientDescentParameters, self).__init__( - learning_rate, use_gradient_accumulation, - pipeline_execution_with_tensor_core) - - -class TPUEmbedding(object): - """API for using TPU for embedding. - - Example: - ``` - table_config_user = tpu_embedding.TableConfig( - vocabulary_size=4, dimension=2, - initializer=initializer, combiner='mean') - table_to_config_dict = {'video': table_config_video, - 'user': table_config_user} - feature_to_table_dict = {'watched': 'video', - 'favorited': 'video', - 'friends': 'user'} - batch_size = 4 - num_hosts = 1 - optimization_parameters = tpu_embedding.AdagradParameters(1., 1.) - mode = tpu_embedding.TRAINING - embedding = tpu_embedding.TPUEmbedding( - table_to_config_dict, feature_to_table_dict, - batch_size, num_hosts, mode, optimization_parameters) - - batch_size_per_core = embedding.batch_size_per_core - sparse_features_list = [] - for host in hosts: - with ops.device(host): - for _ in range(embedding.num_cores_per_host): - sparse_features = {} - sparse_features['watched'] = sparse_tensor.SparseTensor(...) - sparse_features['favorited'] = sparse_tensor.SparseTensor(...) - sparse_features['friends'] = sparse_tensor.SparseTensor(...) - sparse_features_list.append(sparse_features) - - enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list) - embedding_variables_and_ops = embedding.create_variables_and_ops() - - def computation(): - activations = embedding.get_activations() - loss = compute_loss(activations) - - base_optimizer = gradient_descent.GradientDescentOptimizer( - learning_rate=1) - cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer( - base_optimizer) - - train_op = cross_shard_optimizer.minimize(loss) - # `train_op` and `send_gradients_op` must happen in order. - with ops.control_dependencies([train_op]): - send_gradients_op = embedding.generate_send_gradients_op() - with ops.control_dependencies([send_gradients_op]): - loss = array_ops.identity(loss) - - loss = tpu.shard(computation, - num_shards=embedding.num_cores) - - with self.test_session() as sess: - sess.run(tpu.initialize_system(embedding_config= - embedding.config_proto)) - sess.run(variables.global_variables_initializer()) - sess.run(embedding.init_ops) - sess.run(embedding_variables_and_ops.load_ops) - sess.run(enqueue_ops) - loss_val = sess.run(loss) - ``` - """ - - # TODO(shizhiw): Instead of `feature_to_table_dict` which maps to table - # name, consider `feature_to_config_dict` which maps to `FeatureConfig`. - # `FeatureConfig` could have fields other than table name. For example, it - # could have a field to indicate that the feature should not be used to - # update embedding table (cr/204852758, cr/204940540). Also, this can support - # different combiners for different features within the same table. - # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it - # to `FeatureConfig`? - - # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and - # `feature_to_table_dict` lists of `TableSpec` and `FeatureSpec` respectively? - - # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate - # for-loops around construction of inputs. - - # `optimization_parameter` applies to all tables. If the need arises, - # we can add `optimization_parameters` to `TableConfig` to override this - # global setting. - @experimental - def __init__(self, - table_to_config_dict, - feature_to_table_dict, - batch_size, - mode, - master, - optimization_parameters=None): - """API for using TPU for embedding lookups. - - Args: - table_to_config_dict: A dictionary mapping from string of table name to - `TableConfig`. Table refers to an embedding table, e.g. `params` - argument to `tf.nn.embedding_lookup_sparse()`. - feature_to_table_dict: A dictionary mapping from string of feature name - to string of table name. Feature refers to ids to lookup in embedding - table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. - batch_size: An `int` representing the global batch size. - mode: `TRAINING` or `INFERENCE`. - master: A `string` representing the TensorFlow master to use. - optimization_parameters: `AdagradParameters`, `AdamParameters`, - `Stochasticgradientdescentparameters`. Must be set in training and must - be `None` in inference. - - Raises: - ValueError: if any input is invalid. - """ - _validate_table_to_config_dict(table_to_config_dict) - # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`. - self._table_to_config_dict = _create_ordered_dict(table_to_config_dict) - self._combiners = _create_combiners(self._table_to_config_dict) - - _validate_feature_to_table_dict(table_to_config_dict, feature_to_table_dict) - self._feature_to_table_dict = _create_ordered_dict(feature_to_table_dict) - self._table_to_features_dict = _create_table_to_features_dict( - self._feature_to_table_dict) - - self._batch_size = batch_size - - self._master = master - self._tpu_system_metadata = ( - tpu_system_metadata_lib._query_tpu_system_metadata(self._master)) # pylint: disable=protected-access - if self._tpu_system_metadata.num_cores == 0: - raise ValueError('TPUEmbedding needs TPUs, but master {} does not have ' - 'TPUs.'.format(self._master)) - self._num_hosts = self._tpu_system_metadata.num_hosts - self._hosts = [device.name for device in self._tpu_system_metadata.devices - if 'device:CPU:' in device.name] - self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host - self._num_cores = self._tpu_system_metadata.num_cores - - _validate_batch_size(self._batch_size, self._num_cores) - self._batch_size_per_core = self._batch_size // self._num_cores - - self._init_ops = [] - - # TODO(shizhiw): remove `mode`? - if mode == TRAINING: - _validate_optimization_parameters(optimization_parameters) - self._optimization_parameters = optimization_parameters - elif mode == INFERENCE: - if optimization_parameters is not None: - raise ValueError('`optimization_parameters` should be `None` ' - 'for inference mode.') - self._optimization_parameters = ( - StochasticGradientDescentParameters(1.)) - else: - raise ValueError('`mode` only supports {} and {}; got {}.' - .format(TRAINING, INFERENCE, mode)) - self._mode = mode - - # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler` - # and create special handler for inference that inherits from - # StochasticGradientDescentHandler with more user-friendly error message - # on get_slot(). - self._optimizer_handler = _get_optimization_handler( - self._optimization_parameters) - - dummy_table_variables_init_op = self._create_dummy_table_variables() - self._init_ops.append(dummy_table_variables_init_op) - - self._config_proto = self._create_config_proto() - - @property - def hosts(self): - """A list of device names for CPU hosts. - - Returns: - A list of device names for CPU hosts. - """ - return copy.copy(self._hosts) - - # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and - # to be consistent with `tpu_embedding_configuration.proto`. - @property - def num_cores_per_host(self): - """Number of TPU cores on a CPU host. - - Returns: - Number of TPU cores on a CPU host. - """ - return self._num_cores_per_host - - @property - def num_cores(self): - """Total number of TPU cores on all hosts. - - Returns: - Total number of TPU cores on all hosts. - """ - return self._num_cores - - @property - def batch_size_per_core(self): - """Batch size for each TPU core. - - The sparse tensors in `sparse_features_list` to `generate_enqueue_ops` - must have batch dimension equal to this. - - Returns: - Batch size for each TPU core. - """ - return self._batch_size_per_core - - @property - def config_proto(self): - """Create embedding config proto for `tpu.initialize_system()`. - - Returns: - an `TPUEmbeddingConfiguration` proto describing the desired - configuration of the hardware embedding lookup tables, which - is passed to `tpu.initialize_system()`. - """ - return self._config_proto - - @property - def init_ops(self): - """Initialization ops for TPU embedding. - - It must be called after all global variables have been initialized, - i.e. after `global_variables_initializer()`, as it loads embedding - tables into TPU. - - Returns: - A list of ops. - """ - return self._init_ops - - @property - def feature_to_table_dict(self): - return copy.copy(self._feature_to_table_dict) - - def _create_config_proto(self): - """Create `TPUEmbeddingConfiguration`.""" - config_proto = elc.TPUEmbeddingConfiguration() - for table in self._table_to_config_dict: - table_descriptor = config_proto.table_descriptor.add() - table_descriptor.name = table - - table_config = self._table_to_config_dict[table] - table_descriptor.vocabulary_size = table_config.vocabulary_size - table_descriptor.dimension = table_config.dimension - - features_for_table = self._table_to_features_dict[table] - table_descriptor.num_features = len(features_for_table) - - table_descriptor.optimization_parameters.learning_rate.constant = ( - self._optimization_parameters.learning_rate) - table_descriptor.optimization_parameters.use_gradient_accumulation = ( - self._optimization_parameters.use_gradient_accumulation) - self._optimizer_handler.set_optimization_parameters(table_descriptor) - - config_proto.mode = self._mode - config_proto.batch_size_per_tensor_core = self._batch_size_per_core - config_proto.num_hosts = self._num_hosts - config_proto.num_tensor_cores = self._num_cores - config_proto.sharding_strategy = elc.TPUEmbeddingConfiguration.DIV_DEFAULT - config_proto.pipeline_execution_with_tensor_core = ( - self._optimization_parameters.pipeline_execution_with_tensor_core) - - return config_proto - - def create_variables_and_ops(self, embedding_variable_name_by_table=None, - slot_variable_names_by_table=None): - """Create embedding and slot variables, with ops to load and retrieve them. - - Args: - embedding_variable_name_by_table: A dictionary mapping from string of - table name to string of embedding variable name. If `None`, - defaults from `get_default_slot_variable_names()` will be used. - slot_variable_names_by_table: A dictionary mapping from string of table - name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If - `None`, defaults from `get_default_slot_variable_names()` will be used. - - Returns: - `tpu_embedding.VariablesAndOps` with: - A dictionary mapping from string of table name to embedding variables, - A dictionary mapping from string of table name to AdagradSlotVariable, - AdamSlotVariables etc with slot variables, - A list of ops to load embedding and slot variables on CPU to TPU, - A list of ops to retrieve embedding and slot variables from TPU to CPU. - """ - embedding_variables_by_table = {} - slot_variables_by_table = {} - load_ops = [] - retrieve_ops = [] - for table in self._table_to_config_dict: - if embedding_variable_name_by_table: - embedding_variable_name = embedding_variable_name_by_table[table] - else: - embedding_variable_name = table - if slot_variable_names_by_table: - slot_variable_names = slot_variable_names_by_table[table] - else: - slot_variable_names = ( - self._optimizer_handler.get_default_slot_variable_names(table)) - - device_fn = _create_device_fn(self._hosts) - with ops.device(device_fn): - table_variables = _create_partitioned_variables( - name=embedding_variable_name, - num_hosts=self._num_hosts, - vocabulary_size=self._table_to_config_dict[table].vocabulary_size, - embedding_dimension=self._table_to_config_dict[table].dimension, - initializer=self._table_to_config_dict[table].initializer, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - embedding_variables_by_table[table] = table_variables - - slot_variables_for_table, load_ops_for_table, retrieve_ops_for_table = ( - self._optimizer_handler.create_variables_and_ops( - table, slot_variable_names, self._num_hosts, - self._table_to_config_dict[table], table_variables) - ) - slot_variables_by_table[table] = slot_variables_for_table - load_ops.extend(load_ops_for_table) - retrieve_ops.extend(retrieve_ops_for_table) - return VariablesAndOps(embedding_variables_by_table, - slot_variables_by_table, - load_ops, retrieve_ops) - - def _create_dummy_table_variables(self): - """Create dummy embedding table variables. - - The sole purpose of these dummy variables are to trigger gradient - calcuation wrt them so that the gradients wrt activation can be captured - and later sent to TPU embedding. - - Returns: - Initializer for these variables. - - Raises: - RuntimeError: if collection to store gradients already exists and is not - empty. - """ - self._dummy_table_variables = [] - # TODO(shizhiw): remove table id. - for table_id, table in enumerate(self._table_to_features_dict): - self._dummy_table_variables.append( - variable_scope.get_variable( - 'tpu_embedding_dummy_table_variable_%s' % table, - dtype=dtypes.float32, - shape=[1], - use_resource=True, - trainable=True, - # TODO(shizhiw): Remove these dummy variables as - # tensorflow optimizer creates slot variable for them which - # is undesirable. - # e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}. - # Explicitly specifying collections prevents this variable from - # being added to the GLOBAL_VARIABLES collection, so that Saver() - # ignores it. - collections=['tpu_embedding_dummy_table_variables'])) - - g = ops.get_default_graph() - table_gradients = g.get_collection_ref( - 'tpu_embedding_gradients_table_%d' % table_id) - if table_gradients: - raise RuntimeError( - 'tpu_embedding_gradients_table_%d is not empty.' % table_id) - table_gradients.extend([None] * len(self._table_to_features_dict[table])) - - return variables.variables_initializer( - self._dummy_table_variables, - name='tpu_embedding_dummy_table_variables_init') - - def generate_enqueue_ops(self, sparse_features_list): - """Generate enqueue ops. - - Args: - sparse_features_list: a list of dictionary mapping from string - of feature names to sparse tensor. Each dictionary is for one - TPU core. Dictionaries for the same core should be contiguous - on the list. - - Returns: - Ops to enqueue to TPU for embedding. - """ - self._validate_generate_enqueue_ops_sparse_features_list( - sparse_features_list) - return [ - self._generate_enqueue_op( - sparse_features, device_ordinal=i % self._num_cores_per_host) - for i, sparse_features in enumerate(sparse_features_list) - ] - - def _validate_generate_enqueue_ops_sparse_features_list( - self, sparse_features_list): - """Validate `sparse_features_list`.""" - if len(sparse_features_list) != self._num_cores: - raise ValueError('Length of `sparse_features_list` should match the ' - 'number of cores; ' - '`len(sparse_features_list)` is {}, ' - 'number of cores is {}.'.format( - len(sparse_features_list), self._num_cores)) - - feature_set = set(self._feature_to_table_dict.keys()) - contiguous_device = None - for i, sparse_features in enumerate(sparse_features_list): - used_feature_set = set(sparse_features.keys()) - - # Check features are valid. - missing_feature_set = feature_set - used_feature_set - if missing_feature_set: - raise ValueError('`sparse_features_list[{}]` misses a feature that is ' - 'in `feature_to_config_dict`: {}.'.format( - i, missing_feature_set)) - - extra_feature_set = used_feature_set - feature_set - if extra_feature_set: - raise ValueError('`sparse_features_list[{}]` has a feature that is not ' - 'in `feature_to_config_dict`: {}.'.format( - i, extra_feature_set)) - - device = None - device_feature = None - for feature, tensor in six.iteritems(sparse_features): - if not isinstance(tensor, sparse_tensor.SparseTensor): - raise ValueError('`sparse_features_list[{}]` has a feature that is ' - 'not mapped to `SparseTensor`. ' - '`feature`: {}, type: {}'.format( - i, feature, type(tensor))) - - # Check all features are on the same device. - if device is None: - device = tensor.op.device - device_feature = feature - else: - if device != tensor.op.device: - raise ValueError('Devices are different between features in ' - '`sparse_features_list[{}]`; ' - 'devices: {}, {}; features: {}, {}.'.format( - i, device, tensor.op.device, feature, - device_feature)) - - if i % self._num_cores_per_host: - if device != contiguous_device: - raise ValueError('We expect the `sparse_features` which are on the ' - 'same host to be contiguous in ' - '`sparse_features_list`, ' - '`sparse_features_list[{}]` is on device {}, ' - 'but is expected to be on device {}.'.format( - i, device, contiguous_device)) - else: - contiguous_device = device - - def _generate_enqueue_op(self, sparse_features, device_ordinal): - with ops.colocate_with(list(sparse_features.values())[0]): - sample_idcs, embedding_idcs, aggregation_weights = ( - self._format_for_tpu_embedding_sparse_batch(sparse_features)) - return tpu_ops.enqueue_tpu_embedding_sparse_batch( - sample_idcs, - embedding_idcs, - aggregation_weights, - combiners=self._combiners, - device_ordinal=device_ordinal) - - def _format_for_tpu_embedding_sparse_batch(self, sparse_features): - """Format sparse features for `enqueue_tpu_embedding_sparse_batch()`. - - Args: - sparse_features: a `Dict` of `SparseTensor`s for embedding. - - Returns: - Arguments for `enqueue_tpu_embedding_sparse_batch()`. - """ - - sample_idcs, embedding_idcs, aggregation_weights = list(), list(), list() - for table in self._table_to_features_dict: - sample_t, indices_t, weights_t = list(), list(), list() - - features = self._table_to_features_dict[table] - for i, feature in enumerate(features): - tensor = sparse_features[feature] - sample_indices = tensor.indices[:, 0] - embedding_indices = tensor.values - weights = array_ops.ones_like(embedding_indices) - sample_t.append(i * self._batch_size_per_core + sample_indices) - indices_t.append(embedding_indices) - weights_t.append(weights) - - sample_idcs.append( - math_ops.cast(array_ops.concat(sample_t, axis=0), dtype=dtypes.int32)) - embedding_idcs.append( - math_ops.cast( - array_ops.concat(indices_t, axis=0), dtype=dtypes.int32)) - aggregation_weights.append( - math_ops.cast( - array_ops.concat(weights_t, axis=0), dtype=dtypes.float32)) - - return sample_idcs, embedding_idcs, aggregation_weights - - def get_activations(self): - """Get activations for features. - - This should be called within `computation` that is passed to - `tpu.replicate` and friends. - - Returns: - A dictionary mapping from `String` of feature name to `Tensor` - of activation. - """ - recv_activations = tpu_ops.recv_tpu_embedding_activations( - num_outputs=len(self._table_to_config_dict), - config=self._config_proto.SerializeToString()) - - activations = collections.OrderedDict() - for table_id, table in enumerate(self._table_to_features_dict): - features = self._table_to_features_dict[table] - for lookup_id, feature in enumerate(features): - start_row = lookup_id * self._batch_size_per_core - end_row = start_row + self._batch_size_per_core - activations[feature] = gen_tpu_ops.tpu_embedding_activations( - self._dummy_table_variables[table_id], - recv_activations[table_id][start_row:end_row, :], - table_id=table_id, - lookup_id=lookup_id) - return activations - - # TODO(shizhiw): Make `gradient_multiplier` per feature. Setting it to 0 would - # have the effect of `tf.stop_gradients()`. - # TODO(shizhiw): Consider alternative ways to capture gradients wrt embedding - # layer outputs to remove `_dummy_table_variables`, - # `_embedding_activation_grad` and `tpu_embedding_gradients_table_%d'. - def generate_send_gradients_op(self, gradient_multipliers=None): - """Retrieve gradients from collections and send them to TPU embedding. - - Args: - gradient_multipliers: None, or dict mapping table names to gradient - multiplier Tensors. - - Returns: - SendTPUEmbeddingGradients Op. - - Raises: - ValueError: If required gradients have not been defined. - RuntimeError: If `mode` is not `TRAINING`. - """ - if self._mode != TRAINING: - raise RuntimeError('Only in training mode gradients need to ' - 'be sent to TPU embedding; got mode {}.' - .format(self._mode)) - - g = ops.get_default_graph() - gradients = list() - for table_id, table in enumerate(self._table_to_config_dict): - table_gradients = g.get_collection( - 'tpu_embedding_gradients_table_%d' % table_id) - if any(gradient is None for gradient in table_gradients): - raise ValueError( - 'Table {}/{} has undefined gradients: this is probably because the ' - 'model asked TPUEmbedding to compute activations that were not ' - 'used.'.format(table_id, table)) - concat_table_grads = array_ops.concat(table_gradients, axis=0) - if gradient_multipliers is not None: - concat_table_grads *= gradient_multipliers[table.name] - gradients.append(concat_table_grads) - - return tpu_ops.send_tpu_embedding_gradients( - inputs=gradients, config=self.config_proto.SerializeToString()) - - -def _validate_table_to_config_dict(table_to_config_dict): - """Validate `table_to_config_dict`.""" - for k, v in six.iteritems(table_to_config_dict): - if not isinstance(v, TableConfig): - raise ValueError('Value of `table_to_config_dict` must be of type ' - '`TableConfig`, got {} for {}.'.format(type(v), k)) - - -def _validate_feature_to_table_dict(table_to_config_dict, - feature_to_table_dict): - """Validate `feature_to_table_dict`.""" - used_table_set = set(feature_to_table_dict.values()) - table_set = set(table_to_config_dict.keys()) - - unused_table_set = table_set - used_table_set - if unused_table_set: - raise ValueError('`table_to_config_dict` specifies table that is not ' - 'used in `feature_to_table_dict`: {}.' - .format(unused_table_set)) - - extra_table_set = used_table_set - table_set - if extra_table_set: - raise ValueError('`feature_to_table_dict` refers to a table that is not ' - 'specified in `table_to_config_dict`: {}.' - .format(extra_table_set)) - - -def _validate_batch_size(batch_size, num_cores): - if batch_size % num_cores: - raise ValueError('`batch_size` is not a multiple of number of ' - 'cores. `batch_size`={}, `_num_cores`={}.'.format( - batch_size, num_cores)) - - -def _validate_optimization_parameters(optimization_parameters): - if not isinstance(optimization_parameters, _OptimizationParameters): - raise ValueError('`optimization_parameters` must inherit from ' - '`_OptimizationPramaters`. ' - '`type(optimization_parameters)`={}'.format( - type(optimization_parameters))) - - -class _OptimizerHandler(object): - """Interface class for handling optimizer specific logic.""" - - def __init__(self, optimization_parameters): - self._optimization_parameters = optimization_parameters - - def set_optimization_parameters(self, table_descriptor): - raise NotImplementedError() - - def get_default_slot_variable_names(self, table): - raise NotImplementedError() - - def create_variables_and_ops(self, table, slot_variable_names, num_hosts, - table_config, table_variables): - raise NotImplementedError() - - -class _AdagradHandler(_OptimizerHandler): - """Handles Adagrad specific logic.""" - - def __init__(self, optimization_parameters): - super(_AdagradHandler, self).__init__(optimization_parameters) - self._table_to_accumulator_variables_dict = {} - - def set_optimization_parameters(self, table_descriptor): - table_descriptor.optimization_parameters.adagrad.SetInParent() - - def get_default_slot_variable_names(self, table): - return AdagradSlotVariableName('{}/{}'.format(table, 'Adagrad')) - - def create_variables_and_ops(self, table, slot_variable_names, num_hosts, - table_config, table_variables): - accumulator_initializer = init_ops.constant_initializer( - self._optimization_parameters.initial_accumulator) - accumulator_variables = _create_partitioned_variables( - name=slot_variable_names.accumulator, - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=accumulator_initializer) - slot_variables = AdagradSlotVariable(accumulator_variables) - - load_ops = [] - retrieve_ops = [] - for host_id, table_variable, accumulator_variable in (zip( - range(num_hosts), table_variables, accumulator_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops.load_tpu_embedding_adagrad_parameters( - parameters=table_variable, - accumulators=accumulator_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table, retrieved_accumulator = ( - tpu_ops.retrieve_tpu_embedding_adagrad_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table), - state_ops.assign(accumulator_variable, retrieved_accumulator)) - - load_ops.append(load_parameters_op) - retrieve_ops.append(retrieve_parameters_op) - return slot_variables, load_ops, retrieve_ops - - -class _AdamHandler(_OptimizerHandler): - """Handles Adam specific logic.""" - - def __init__(self, optimization_parameters): - super(_AdamHandler, self).__init__(optimization_parameters) - self._table_to_m_variables_dict = {} - self._table_to_v_variables_dict = {} - - def set_optimization_parameters(self, table_descriptor): - table_descriptor.optimization_parameters.adam.beta1 = ( - self._optimization_parameters.beta1) - table_descriptor.optimization_parameters.adam.beta2 = ( - self._optimization_parameters.beta2) - table_descriptor.optimization_parameters.adam.epsilon = ( - self._optimization_parameters.epsilon) - table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( - not self._optimization_parameters.lazy_adam) - table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( - self._optimization_parameters.sum_inside_sqrt) - - def get_default_slot_variable_names(self, table): - return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'), - '{}/{}/v'.format(table, 'Adam')) - - def create_variables_and_ops(self, table, slot_variable_names, num_hosts, - table_config, table_variables): - m_initializer = init_ops.zeros_initializer() - m_variables = _create_partitioned_variables( - name=slot_variable_names.m, - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=m_initializer) - v_initializer = init_ops.zeros_initializer() - v_variables = _create_partitioned_variables( - name=slot_variable_names.v, - num_hosts=num_hosts, - vocabulary_size=table_config.vocabulary_size, - embedding_dimension=table_config.dimension, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - initializer=v_initializer) - slot_variables = AdamSlotVariables(m_variables, v_variables) - - load_ops = [] - retrieve_ops = [] - for host_id, table_variable, m_variable, v_variable in (zip( - range(num_hosts), table_variables, - m_variables, v_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops.load_tpu_embedding_adam_parameters( - parameters=table_variable, - momenta=m_variable, - velocities=v_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table, retrieved_m, retrieved_v = ( - tpu_ops.retrieve_tpu_embedding_adam_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table), - state_ops.assign(m_variable, retrieved_m), - state_ops.assign(v_variable, retrieved_v)) - - load_ops.append(load_parameters_op) - retrieve_ops.append(retrieve_parameters_op) - return slot_variables, load_ops, retrieve_ops - - -class _StochasticGradientDescentHandler(_OptimizerHandler): - """Handles stochastic gradient descent specific logic.""" - - def set_optimization_parameters(self, table_descriptor): - (table_descriptor.optimization_parameters.stochastic_gradient_descent - .SetInParent()) - - def get_default_slot_variable_names(self, table): - return None - - def create_variables_and_ops(self, table, slot_variable_names, num_hosts, - table_config, table_variables): - del table_config - - load_ops = [] - retrieve_ops = [] - for host_id, table_variable in (zip( - range(num_hosts), table_variables)): - with ops.colocate_with(table_variable): - load_parameters_op = ( - tpu_ops - .load_tpu_embedding_stochastic_gradient_descent_parameters( - parameters=table_variable, - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieved_table = ( - tpu_ops - .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( - table_name=table, - num_shards=num_hosts, - shard_id=host_id)) - retrieve_parameters_op = control_flow_ops.group( - state_ops.assign(table_variable, retrieved_table)) - - load_ops.append(load_parameters_op) - retrieve_ops.append(retrieve_parameters_op) - return None, load_ops, retrieve_ops - - -def _get_optimization_handler(optimization_parameters): - if isinstance(optimization_parameters, AdagradParameters): - return _AdagradHandler(optimization_parameters) - elif isinstance(optimization_parameters, AdamParameters): - return _AdamHandler(optimization_parameters) - elif isinstance(optimization_parameters, StochasticGradientDescentParameters): - return _StochasticGradientDescentHandler(optimization_parameters) - else: - return NotImplementedError() - - -def _create_ordered_dict(d): - """Create an OrderedDict from Dict.""" - return collections.OrderedDict((k, d[k]) for k in sorted(d)) - - -def _create_combiners(table_to_config_dict): - return [table_to_config_dict[t].combiner for t in table_to_config_dict] - - -def _create_table_to_features_dict(feature_to_table_dict): - """Create mapping from table to a list of its features.""" - table_to_features_dict_tmp = {} - for feature, table in six.iteritems(feature_to_table_dict): - if table in table_to_features_dict_tmp: - table_to_features_dict_tmp[table].append(feature) - else: - table_to_features_dict_tmp[table] = [feature] - - table_to_features_dict = collections.OrderedDict() - for table in sorted(table_to_features_dict_tmp): - table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) - return table_to_features_dict - - -def _create_device_fn(hosts): - """Create device_fn() to use with _create_partitioned_variables().""" - - def device_fn(op): - """Returns the `device` for `op`.""" - part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) - - if part_match: - idx = int(part_match.group(1)) - else: - raise RuntimeError('Internal Error: ' - 'Expected %s to contain /part_*.' % op.name) - - device = hosts[idx] - return device - - return device_fn - - -def _create_partitioned_variables(name, - num_hosts, - vocabulary_size, - embedding_dimension, - initializer, - collections=None): # pylint: disable=redefined-outer-name - """Creates ParitionedVariables based on `num_hosts` for `table`.""" - # TODO(shizhiw): automatically place embedding lookup elsewhere? - if vocabulary_size < num_hosts: - raise ValueError('`vocabulary_size`({}) is smaller than `num_hosts`({}). ' - 'As TPU embedding is not optimized for small tables, ' - 'please consider other ways for this embedding lookup.') - - return list(variable_scope.get_variable( - name, - shape=(vocabulary_size, embedding_dimension), - partitioner=partitioned_variables.fixed_size_partitioner(num_hosts), - dtype=dtypes.float32, - initializer=initializer, - collections=collections, - trainable=False)) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_embedding import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/python/training/mode_keys_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py similarity index 63% rename from tensorflow/python/training/mode_keys_test.py rename to tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py index c4435b7d4870ac1675a3f2f4d80def111dc85ae5..308adc77e9ad2d912d0461512655b55faa53da60 100644 --- a/tensorflow/python/training/mode_keys_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding_gradient.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for `tf.train.ModeKeys.""" +"""Stub file to maintain backwards compatibility.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.platform import test -from tensorflow.python.training import mode_keys - - -class ModeKeysTest(test.TestCase): - - def testKeyEquality(self): - self.assertEqual(mode_keys.ModeKeys.PREDICT, 'predict') - self.assertEqual(mode_keys.ModeKeys.TRAIN, 'train') - self.assertEqual(mode_keys.ModeKeys.TEST, 'test') +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_embedding_gradient import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 6a3ed9bb79502505d64156c6a405d9d57ee83eb5..893118412e1363ce50416e6ef36692bc23d04179 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1,3733 +1,33 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""TPUEstimator class.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy -import os -import signal -import sys -import threading -import time - -import numpy as np -import six -from six.moves import queue as Queue # pylint: disable=redefined-builtin -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.ops import tpu_ordinal_selector_op -from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding -from tensorflow.contrib.tpu.python.tpu import error_handling -from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional -from tensorflow.contrib.tpu.python.tpu import session_support -from tensorflow.contrib.tpu.python.tpu import tensor_tracer -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_context -from tensorflow.contrib.tpu.python.tpu import tpu_feed -from tensorflow.contrib.tpu.python.tpu import training_loop -from tensorflow.contrib.tpu.python.tpu import util as util_lib -from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import AdamParameters # pylint: disable=unused-import -from tensorflow.contrib.tpu.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec # pylint: disable=unused-import -from tensorflow.contrib.training.python.training import hparam -from tensorflow.core.framework import variable_pb2 -from tensorflow.core.framework.summary_pb2 import Summary -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as tf_session -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest as data_nest -from tensorflow.python.estimator import estimator as estimator_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export_output as export_output_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops_v2 as contrib_summary -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import tag_constants -from tensorflow.python.summary import summary -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import evaluation -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training -from tensorflow.python.training import training_util -from tensorflow.python.util import function_utils -from tensorflow.python.util import nest -from tensorflow.python.util import tf_inspect - -_INITIAL_LOSS = 1e7 -_ZERO_LOSS = 0. -_TPU_ESTIMATOR = 'tpu_estimator' -_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' -_BATCH_SIZE_KEY = 'batch_size' -_CTX_KEY = 'context' -_USE_TPU_KEY = 'use_tpu' -_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' -_ONE_GIGABYTE = 1024 * 1024 * 1024 -_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' -_TPU_TRAIN_OP = '_tpu_train_op' -_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' -_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor' - -# Ideally _USE_TPU_KEY should be reserved as well. However there are already -# models that make use of this key, thus it can not be reserved now to prevent -# breakage. In the long run, we would like to mitigate this by migrating models -# off of using _USE_TPU_KEY. -_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] - -# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is -# only used for per-core based deployments. For per-host based pipelines, if a -# user returns a Dataset instance it will be automatically wrapped in a -# tf.while_loop (This can be disabled by returning features and labels -# explicitly). -_WRAP_INPUT_FN_INTO_WHILE_LOOP = False - -ops.register_proto_function( - '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR), - proto_type=variable_pb2.VariableDef, - to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access - from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access - - -def _is_iterable(obj): - """A Python 2 and 3 compatible util to check whether `obj` is iterable.""" - try: - iter(obj) - return True - except TypeError: - return False - - -class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext): - - def AddOp(self, op): - if op.type in [ - 'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary', - 'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2' - ]: - raise ValueError('Use tf.contrib.summary inside of host_calls.') - - -def _create_global_step(graph): - graph = graph or ops.get_default_graph() - if training.get_global_step(graph) is not None: - raise ValueError('"global_step" already exists.') - # Create in proper graph and base name_scope. - with graph.as_default() as g, g.name_scope(None): - return variable_scope.get_variable( - ops.GraphKeys.GLOBAL_STEP, - shape=[], - dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - trainable=False, - use_resource=True, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) - - -def _create_or_get_iterations_per_loop(): - """Creates or gets the iterations_per_loop variable. - - In TPUEstimator, the user provided computation, the model_fn, is wrapped - inside a tf.while_loop for peak performance. The iterations of the loop are - specified by this variable, which adjusts its value on the CPU after each TPU - program execution and before the next TPU execution. - - The purpose of using a variable, rather then a constant, is to allow - TPUEstimator adapt the TPU training iterations according to the final steps - specified by users. For example, if the user sets the iterations_per_loop as 4 - in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop - variable will have the following value before each TPU training. - - - 1-th TPU execution: iterations_per_loop = 4 - - 2-th TPU execution: iterations_per_loop = 4 - - 3-th TPU execution: iterations_per_loop = 2 - - As model_fn increases the global step once per train_op invocation, the global - step is 10 after all TPU executions, matching the steps=10 inputs passed in by - users. - - Returns: - A TF non-trainable resource variable. - - Raises: - RuntimeError: If multi iterations_per_loop variables were found. - """ - graph = ops.get_default_graph() - collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) - iter_vars = graph.get_collection(collection_name) - if len(iter_vars) == 1: - return iter_vars[0] - elif len(iter_vars) > 1: - raise RuntimeError('Multiple iterations_per_loop_var in collection.') - - with ops.colocate_with(training_util.get_global_step()): - with variable_scope.variable_scope( - _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], - use_resource=True) - - -def _sync_variables_ops(ctx): - """Create varriables synchronization ops. - - Gets the variables back from TPU nodes. This means the variables updated - by TPU will now be *synced* to host memory. - In BROADCAST mode, we skip this sync since the variables are ususally too - big to transmit via RPC. - - Args: - ctx: A `_InternalTPUContext` instance with mode. - - Returns: - A list of sync ops. - """ - - if not ctx.is_input_broadcast_with_iterators(): - return [ - array_ops.check_numerics(v.read_value(), - 'Gradient for %s is NaN' % v.name).op - for v in variables.trainable_variables() - ] - else: - return [control_flow_ops.no_op()] - - -def _increase_eval_step_op(iterations_per_loop): - """Returns an op to increase the eval step for TPU evaluation. - - Args: - iterations_per_loop: Tensor. The number of eval steps running in TPU system - before returning to CPU host for each `Session.run`. - - Returns: - An operation - """ - eval_step = evaluation._get_or_create_eval_step() # pylint: disable=protected-access - # Estimator evaluate increases 1 by default. So, we increase the difference. - return state_ops.assign_add( - eval_step, - math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype), - use_locking=True) - - -def _extract_key_names(tensor_or_dict): - if isinstance(tensor_or_dict, dict): - return sorted(tensor_or_dict.keys()) - return [] - - -class _SIGNAL(object): - """Signal used to control the thread of infeed/outfeed. - - All preserved signals must be negative numbers. Positive numbers are used to - indicate the number of iterations for next training/evaluation loop. - """ - NEXT_BATCH = -1 - STOP = -2 - - -class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - - See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and - `export_outputs`. - - For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where - `metric_fn` runs on CPU to generate metrics and `tensors` represents the - `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. - To be precise, TPU evaluation expects a slightly different signature from the - `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a - dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. - The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The - `tensors` usually specify the model logits, which are transferred back from - TPU system to CPU host. All tensors must have be batch-major, i.e., the batch - size is the first dimension. Once all tensors are available at CPU host from - all shards, they are concatenated (on CPU) and passed as positional arguments - to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - a dict. `metric_fn` takes the `tensors` and returns a dict from metric string - name to the result of calling a metric function, namely a `(metric_tensor, - update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the - `eval_metrics`. - - `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This - function should not capture any Tensors in `model_fn`. - - `host_call` is a tuple of a `function` and a list or dictionary of `tensors` - to pass to that function and returns a list of Tensors. `host_call` currently - works for train() and evaluate(). The Tensors returned by the function is - executed on the CPU on every step, so there is communication overhead when - sending tensors from TPU to CPU. To reduce the overhead, try reducing the - size of the tensors. The `tensors` are concatenated along their major (batch) - dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with `tf.contrib.summary.create_file_writer`. - """ - - def __new__(cls, - mode, - predictions=None, - loss=None, - train_op=None, - eval_metrics=None, - export_outputs=None, - scaffold_fn=None, - host_call=None, - training_hooks=None, - evaluation_hooks=None, - prediction_hooks=None): - """Creates a validated `TPUEstimatorSpec` instance.""" - host_calls = {} - if eval_metrics is not None: - host_calls['eval_metrics'] = eval_metrics - if host_call is not None: - host_calls['host_call'] = host_call - _OutfeedHostCall.validate(host_calls) - - training_hooks = tuple(training_hooks or []) - evaluation_hooks = tuple(evaluation_hooks or []) - prediction_hooks = tuple(prediction_hooks or []) - - for hook in training_hooks + evaluation_hooks + prediction_hooks: - if not isinstance(hook, session_run_hook.SessionRunHook): - raise TypeError('All hooks must be SessionRunHook instances, given: {}' - .format(hook)) - - return super(TPUEstimatorSpec, cls).__new__( - cls, - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metrics=eval_metrics, - export_outputs=export_outputs, - scaffold_fn=scaffold_fn, - host_call=host_call, - training_hooks=training_hooks, - evaluation_hooks=evaluation_hooks, - prediction_hooks=prediction_hooks) - - def as_estimator_spec(self): - """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" - host_calls = {} - if self.eval_metrics is not None: - host_calls['eval_metrics'] = self.eval_metrics - if self.host_call is not None: - host_calls['host_call'] = self.host_call - host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) - eval_metric_ops = None - if self.eval_metrics is not None: - eval_metric_ops = host_call_ret['eval_metrics'] - hooks = None - if self.host_call is not None: - hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] - loss = self.loss - if tensor_tracer.TensorTracer.is_enabled() \ - and self.train_op is not None: - tt = tensor_tracer.TensorTracer() - (loss, tracing_calls) = tt.trace_cpu(ops.get_default_graph(), - loss, self.train_op) - tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls) - tracing_functions = tracing_call_ret.values() - if tracing_functions: - if hooks: - hooks.extend([_OutfeedHostCallHook(tracing_functions)]) - else: - hooks = [_OutfeedHostCallHook(tracing_functions)] - hooks = tuple(hooks or []) - scaffold = self.scaffold_fn() if self.scaffold_fn else None - return model_fn_lib.EstimatorSpec( - mode=self.mode, - predictions=self.predictions, - loss=loss, - train_op=self.train_op, - eval_metric_ops=eval_metric_ops, - export_outputs=self.export_outputs, - scaffold=scaffold, - training_hooks=self.training_hooks + hooks, - evaluation_hooks=self.evaluation_hooks + hooks, - prediction_hooks=self.prediction_hooks + hooks) - - -class _OpQueueContext(object): - """Manages work queue and thread for a infeed/outfeed thread.""" - - def __init__(self, name, target, args): - self._name = name - self._queue = Queue.Queue() - args = (self,) + args - self._thread = threading.Thread(name=name, target=target, args=args) - self._thread.daemon = True - self._thread.start() - - def stop(self): - self._queue.put(_SIGNAL.STOP) - - def send_next_batch_signal(self, iterations): - self._queue.put(iterations) - - def read_iteration_counts(self): - while True: - iterations = self._queue.get(block=True) - logging.debug('%s read iterations %s', self._name, iterations) - if iterations == _SIGNAL.STOP: - logging.info('%s received shutdown signal, stopping.', self._name) - return - yield iterations - - def join(self): - logging.info('Shutting down %s thread.', self._name) - self.stop() - self._thread.join() - - -class _OpSignalOnceQueueContext(_OpQueueContext): - """Manages work queue and thread for a infeed/outfeed thread. - - This subclass only signals once. - """ - - def __init__(self, name, target, args): - super(_OpSignalOnceQueueContext, self).__init__(name, target, args) - self._has_signaled = False - - def send_next_batch_signal(self, iterations): - if not self._has_signaled: - self._queue.put(iterations) - self._has_signaled = True - - -class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): - """A Session hook setting up the TPU initialization, infeed, and outfeed. - - This hook does two major things: - 1. initialize and shutdown TPU system. - 2. launch and join the threads for infeed enqueue and (optional) outfeed - dequeue. - """ - - def __init__(self, - ctx, - enqueue_ops, - dequeue_ops, - tpu_compile_op, - run_infeed_loop_on_coordinator=True, - rendezvous=None, - master=None, - session_config=None, - tpu_init_ops=None): - self._master_job = ctx.master_job - self._enqueue_ops = enqueue_ops - self._dequeue_ops = dequeue_ops - self._rendezvous = rendezvous - self._master = master - self._session_config = session_config - self._init_ops = list(tpu_init_ops or []) - if ctx.embedding_config is None: - self._embedding_layer_config = None - else: - self._embedding_layer_config = ( - ctx.embedding_config.tpu_embedding.config_proto) - self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator - self._initial_infeed_sleep_secs = ( - ctx.config.tpu_config.initial_infeed_sleep_secs) - - self._feed_error = None - self._finished = False - self._should_initialize_tpu = True - self._tpu_compile_op = tpu_compile_op - - def begin(self): - logging.info('TPU job name %s', self._master_job) - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - if self._should_initialize_tpu: - self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] - else: - self._finalize_ops = [] - - summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() - self._init_ops.extend(summary_writer_init_ops) - # Get all the writer resources from the initializer, so we know what to - # flush. - for op in summary_writer_init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def _run_infeed(self, queue_ctx, session): - logging.info('Starting infeed thread controller.') - if self._initial_infeed_sleep_secs: - logging.info('Infeed thread sleeping for %d seconds.', - self._initial_infeed_sleep_secs) - time.sleep(self._initial_infeed_sleep_secs) - logging.info('Infeed thread starting after sleep') - - with self._rendezvous.catch_errors(source='infeed', session=session): - if self._run_infeed_loop_on_coordinator: - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(self._enqueue_ops) - else: - for _ in queue_ctx.read_iteration_counts(): - session.run(self._enqueue_ops) - logging.info('Infeed thread finished, shutting down.') - - def _run_outfeed(self, queue_ctx, session): - logging.info('Starting outfeed thread controller.') - with self._rendezvous.catch_errors(source='outfeed', session=session): - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) - session.run(self._dequeue_ops) - logging.info('Outfeed thread finished, shutting down.') - - def _create_infeed_controller(self, name, target, args): - return _OpQueueContext(name=name, target=target, args=args) - - def _assertCompilationSucceeded(self, result, coord): - proto = tpu_compilation_result.CompilationResultProto() - proto.ParseFromString(result) - if proto.status_error_message: - logging.error('Compilation failed: {}'.format(proto.status_error_message)) - coord.request_stop() - else: - logging.info('Compilation succeeded') - - def after_create_session(self, session, coord): - if self._should_initialize_tpu: - logging.info('Init TPU system') - start = time.time() - with ops.Graph().as_default(): - with tf_session.Session( - self._master, config=self._session_config) as sess: - sess.run( - tpu.initialize_system( - job=self._master_job, - embedding_config=self._embedding_layer_config)) - logging.info('Initialized TPU in %d seconds', time.time() - start) - - session.run(self._init_ops, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - - if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1': - logging.info('Compiling user program: this may take a while...') - self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord) - - self._infeed_controller = self._create_infeed_controller( - name='InfeedController', target=self._run_infeed, args=(session,)) - - self._outfeed_controller = _OpQueueContext( - name='OutfeedController', target=self._run_outfeed, args=(session,)) - - # Enable the worker watchdog to terminate workers on coordinator exit. - watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0')) - if watchdog_timeout > 0: - session_support.start_worker_watchdog(session, - shutdown_timeout=watchdog_timeout) - - def before_run(self, run_context): - self._feed_error = None - - iterations = run_context.session.run(self._iterations_per_loop_var) - - logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) - self._infeed_controller.send_next_batch_signal(iterations) - - logging.info('Dequeue next (%d) batch(es) of data from outfeed.', - iterations) - self._outfeed_controller.send_next_batch_signal(iterations) - - def end(self, session): - self._finished = True - logging.info('Stop infeed thread controller') - self._infeed_controller.join() - self._rendezvous.record_done('infeed') - - logging.info('Stop output thread controller') - self._outfeed_controller.join() - self._rendezvous.record_done('outfeed') - - logging.info('Shutdown TPU system.') - session.run(self._finalize_ops) - - -class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - - def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op, - rendezvous=None, master=None, session_config=None): - super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, - enqueue_ops, - dequeue_ops, - tpu_compile_op=tpu_compile_op, - run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous, - master=master, - session_config=session_config) - - def _create_infeed_controller(self, name, target, args): - return _OpSignalOnceQueueContext(name=name, target=target, args=args) - - -class _TPUStopAtStepHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step. - - This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with - following differences for TPU training: - - 1. This hook sets the variable for iterations_per_loop, which is used by - `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed. - As the hook execution order is not guaranteed, the variable update is - handled in `after_create_session` and `after_run` as - `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`. - - 2. For each training loop (session.run), the global step could be increased - multiple times on TPU. The global step tensor value will be explicitly read - again in `after_run` to ensure the latest value is retrieved to avoid race - condition. - """ - - def __init__(self, iterations, num_steps=None, last_step=None): - """Initializes a `StopAtStepHook`. - - Args: - iterations: The number of iterations to run optimizer per training loop. - num_steps: Number of steps to execute. - last_step: Step after which to stop. - - Raises: - ValueError: If one of the arguments is invalid. - """ - if num_steps is None and last_step is None: - raise ValueError('One of num_steps or last_step must be specified.') - if num_steps is not None and last_step is not None: - raise ValueError('Only one of num_steps or last_step can be specified.') - self._num_steps = num_steps - self._last_step = last_step - self._iterations = iterations - - def _next_iterations(self, global_step, last_step): - gap = last_step - global_step - return min(gap, self._iterations) - - def begin(self): - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError('Global step should be created.') - - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - if self._last_step is None: - self._last_step = global_step + self._num_steps - - iterations = self._next_iterations(global_step, self._last_step) - - self._iterations_per_loop_var.load(iterations, session=session) - - def after_run(self, run_context, run_values): - # Global step cannot be retrieved via SessionRunArgs and before_run due to - # race condition. - global_step = run_context.session.run(self._global_step_tensor) - if global_step >= self._last_step: - run_context.request_stop() - else: - iterations = self._next_iterations(global_step, self._last_step) - self._iterations_per_loop_var.load( - iterations, session=run_context.session) - - -class _SetEvalIterationsHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step.""" - - def __init__(self, num_steps): - """Initializes a `_SetEvalIterationsHook`. - - Args: - num_steps: Number of steps to execute. - """ - self._num_steps = num_steps - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - self._iterations_per_loop_var.load(self._num_steps, session=session) - - -class _StoppingPredictHook(session_run_hook.SessionRunHook): - """Hook that requests stop according to the stopping signal in prediction.""" - - def __init__(self, scalar_stopping_signal): - self._scalar_stopping_signal = scalar_stopping_signal - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - # This is not necessary as we do not run infeed enqueue and outfeed dequeue - # in side threads for prediction model. But it makes the - # TPUInfeedOutfeedSessionHook prints nice message. - self._iterations_per_loop_var.load(1, session=session) - - def before_run(self, run_context): - return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) - - def after_run(self, run_context, run_values): - _ = run_context - scalar_stopping_signal = run_values.results - if _StopSignals.should_stop(scalar_stopping_signal): - # NOTE(xiejw): In prediction, stopping signals are inserted for each - # batch. And we append one more batch to signal the system it should stop. - # The data flow might look like - # - # batch 0: images, labels, stop = 0 (user provided) - # batch 1: images, labels, stop = 0 (user provided) - # ... - # batch 99: images, labels, stop = 0 (user provided) - # batch 100: images, labels, stop = 1 (TPUEstimator appended) - # - # where the final batch (id = 100) is appended by TPUEstimator, so we - # should drop it before returning the predictions to user. - # To achieve that, we throw the OutOfRangeError in after_run. Once - # Monitored Session sees this error in SessionRunHook.after_run, the - # "current" prediction, i.e., batch with id=100, will be discarded - # immediately - raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') - - -def generate_per_core_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, host_device, host_id): - """Generates infeed enqueue ops for per-core input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A fn returns enqueue_ops.""" - num_cores_per_host = ctx.num_of_cores_per_host - per_host_sharded_inputs = [] - for core_ordinal in range(num_cores_per_host): - with ops.name_scope('ordinal_%d' % (core_ordinal)): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, - input_device=host_device, - invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - if inputs.is_dataset: - raise TypeError( - '`input_fn` returning `Dataset` is not yet supported in ' - 'per-Core input pipeline deployment yet. Please set ' - 'TPUConfig.per_host_input_for_training to True or return ' - '`features` and `labels` from `input_fn`') - features, labels = inputs.features_and_labels() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels)) - per_host_sharded_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) - return per_host_enqueue_ops - - return enqueue_ops_fn, captured_infeed_queue - - -def generate_per_host_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - - dataset_initializer = None - - with ops.device(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device, invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - if batch_axis is not None: - raise TypeError('For mode PREDICT, batch_axis is not supported yet.') - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - dataset_initializer = inputs.dataset_initializer() - - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A Fn returning the TPU infeed enqueue ops. - - By providing as a Fn, it can be invoked inside the tf.while_loop such that - the input pipeline for multiple iterations can be executed by one - Session.run call. - - Returns: - list of dict of ops. - """ - with ops.device(device): - num_of_replicas_per_host = ctx.num_of_replicas_per_host - # Convert user input to features and labels. If the user returns a - # dataset, it is initialized and the features and labels extracted via - # `dataset.iterator.get_next()` - features, labels = inputs.features_and_labels() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure(features, labels) - unsharded_tensor_list = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_tensor_list], - tuple_shapes=[t.shape for t in unsharded_tensor_list], - shard_dimensions=batch_axis) - captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_number_of_shards(num_of_replicas_per_host) - per_host_enqueue_ops = ( - infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, - placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function_impl)) - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -def generate_per_host_v2_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - dataset_initializer = None - - with ops.device(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device, invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if not is_dataset: - raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 ' - 'input pipeline configuration.') - - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True, - num_invocations_per_step=ctx.num_of_replicas_per_host) - - dataset_initializer = inputs.dataset_initializer() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """Generates the per_host enqueue ops.""" - control_deps = [] - per_host_sharded_inputs = [] - sparse_features_list = [] - num_replicas_per_host = ctx.num_of_replicas_per_host - cached_signals = None - with ops.device(device): - if not inputs.is_dataset: - raise TypeError('`input_fn` must return a `Dataset` for this mode.') - for _ in range(num_replicas_per_host): - # Use control dependencies to ensure a deterministic ordering. - with ops.control_dependencies(control_deps): - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - # All the replicas share the replica 0's stopping singal. - # This avoids inconsistent state among different model replcias. - if cached_signals: - signals['stopping'] = cached_signals['stopping'] - else: - cached_signals = signals - - features, labels, sparse_features = ( - _tpu_estimator_embedding.split_inputs(ctx, features, labels)) - sparse_features_list.append(sparse_features) - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - control_deps.extend(flattened_inputs) - per_host_sharded_inputs.append(flattened_inputs) - - if inputs_structure_recorder.flattened_input_dims: - input_partition_dims = inputs_structure_recorder.flattened_input_dims - if signals: - input_partition_dims += [None] * len(signals) - # pylint: disable=protected-access - infeed_queue = tpu_feed._PartitionedInfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0]), - host_id=host_id, - input_partition_dims=input_partition_dims, - device_assignment=ctx.device_assignment) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs) - else: - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl) - captured_infeed_queue.capture(infeed_queue) - - if ctx.embedding_config: - per_host_enqueue_ops.extend( - ctx.embedding_config.tpu_embedding.generate_enqueue_ops( - sparse_features_list)) - - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, - num_hosts): - """Generates infeed enqueue ops for one input_fn on all the hosts.""" - captured_infeed_queue = _CapturedObject() - dataset_initializer = None - device_0 = ctx.tpu_host_placement_function(host_id=0) - with ops.device(device_0): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device_0, invocation_index=0) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - dataset_initializer = inputs.dataset_initializer() - num_replicas_per_host = ctx.num_of_replicas_per_host - - def tpu_ordinal_function_impl(replica_id): - if ctx.device_assignment: - return ctx.device_assignment.tpu_ordinal(replica=replica_id) - else: - return replica_id % num_replicas_per_host - - def device_function_impl(replica_id): - return ctx.tpu_host_placement_function(replica_id=replica_id) - - def enqueue_ops_fn(): - """Generates enqueue ops for all the hosts.""" - broadcasted_inputs = [] - flattened_inputs = None # Cache result from input_fn. - signals = None - for host_id in xrange(num_hosts): - with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): - for _ in xrange(ctx.num_of_replicas_per_host): - # Note: input_fn is only called once at host 0 for the first replica. - # The features and labels returned from that invocation are - # broadcasted to other replicas(including the replicas on other - # hosts). - if flattened_inputs is None: - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - broadcasted_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(broadcasted_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - enqueue_ops = infeed_queue.generate_enqueue_ops( - broadcasted_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl, - placement_function=device_function_impl) - - if signals is None: - return enqueue_ops - else: - return { - 'ops': enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, dataset_initializer - - -class _InputPipeline(object): - """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. - - `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from - call site. To be precise, based on the configuration in - `_InternalTPUContext`, it invokes `input_fn` for all cores (usually - multi-host TPU training) or for one host (usually for single-host TPU - evaluation), and sends all `features` and `labels` returned by `input_fn` to - TPU infeed. For per-core invocation, `features` and `labels` are piped to - infeed directly, one tuple for each core. For per-host invocation, `features` - and `labels` are split at host (with respect to `batch_axis`) and piped to all - cores accordingly. - - In addition, flatten/unflatten are handled by `_InputPipeline` also. Model - inputs returned by the `input_fn` can have one of the following forms: - 1. features - 2. (features, labels) - 3. ((arbitrarily nested structure of features), labels) - - Internally, form 1 is reformed to `(features, None)` as features and labels - are passed separately to underlying methods. For TPU training, TPUEstimator - may expect multiple `features` and `labels` tuples one for each core. - - TPUEstimator allows various different structures for inputs (namely `features` - and `labels`). Both `features` and `labels` can be any nested sturcture - supported by TF nest (namely, dict, tuples, namedtuples or any nested - structure of such of Tensors). `labels` could be `None` as well. - - These are flattened before they are passed to the infeed/outfeed library - as that expectes flattend lists. - """ - - class InputsStructureRecorder(object): - """The recorder to record inputs structure.""" - - def __init__(self, input_partition_dims=None): - # Holds the structure of inputs - self._feature_structure = {} - self._flattened_input_dims = None - - if input_partition_dims: - # This should have been validated in TPUConfig. - assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' - if len(input_partition_dims) == 2: - self._feature_dims, self._label_dims = input_partition_dims - else: - self._feature_dims = input_partition_dims[0] - self._label_dims = None - - assert self._feature_dims is not None, ('input_partition_dims[0] must ' - 'not be None') - else: - self._feature_dims = None - self._label_dims = None - - # Internal state. - self._initialized = False - - @property - def flattened_input_dims(self): - assert self._initialized, 'InputsStructureRecorder is not initialized.' - return self._flattened_input_dims - - def has_labels(self): - return 'labels' in self._feature_structure - - def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, - label_dims_names, label_names, has_labels): - """Flatten input dims with the same order as flattened input tensors.""" - flattened_input_dims = [] - if feature_dims_names: - # We need a fixed ordering for matching the tensors in features. - flattened_input_dims.extend( - [feature_dims[name] for name in feature_dims_names]) - else: - flattened_input_dims.append(feature_dims) - - if label_dims_names: - # We need a fixed ordering for matching the tensors in labels. - flattened_input_dims.extend( - [label_dims[name] for name in label_dims_names]) - else: - if label_names: - num_tensors_in_label = len(label_names) - else: - num_tensors_in_label = int(has_labels) - # Setting `None` in input_partition_dims[1] will apply `None` to - # all the tensors in labels, regardless of internal structure. - flattened_input_dims.extend([label_dims] * num_tensors_in_label) - - return flattened_input_dims - - def validate_and_record_structure(self, features, labels): - """Validates and records the structure of `features` and `labels`.""" - # Extract structure. - has_labels = labels is not None - feature_names = _extract_key_names(features) - label_names = _extract_key_names(labels) - - if not self._initialized: - # Record structure. - self._initialized = True - if self._feature_dims is not None: - feature_dims_names = _extract_key_names(self._feature_dims) - if feature_dims_names != feature_names: - raise ValueError( - 'TPUConfig.input_partition_dims[0] mismatched feature' - ' keys. Expected {}, got {}'.format(feature_names, - feature_dims_names)) - - label_dims_names = _extract_key_names(self._label_dims) - if self._label_dims is not None and label_dims_names != label_names: - raise ValueError( - 'TPUConfig.input_partition_dims[1] mismatched label' - ' keys. Expected {}, got {}'.format(label_names, - label_dims_names)) - - self._flattened_input_dims = self._flatten_input_dims( - self._feature_dims, feature_dims_names, self._label_dims, - label_dims_names, label_names, has_labels) - - def flatten_features_and_labels(self, features, labels, signals=None): - """Flattens the `features` and `labels` to a single tensor list.""" - self._feature_structure['features'] = features - if labels is not None: - self._feature_structure['labels'] = labels - if signals is not None: - self._feature_structure['signals'] = signals - return data_nest.flatten(self._feature_structure) - - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. - - Args: - flattened_inputs: Flattened inputs for each shard. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. - - Raises: - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. - """ - - unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, - flattened_inputs) - return _Inputs( - unflattened_inputs['features'], - unflattened_inputs.get('labels'), - signals=unflattened_inputs.get('signals')) - - def __init__(self, input_fn, batch_axis, ctx): - """Constructor. - - Args: - input_fn: input fn for train or eval. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. - ctx: A `_InternalTPUContext` instance with mode. - - Raises: - ValueError: If both `sharded_features` and `num_cores` are `None`. - """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( - ctx.input_partition_dims) - - self._sharded_per_core = ctx.is_input_sharded_per_core() - self._input_fn = input_fn - self._infeed_queue = None - self._ctx = ctx - self._batch_axis = batch_axis - - def generate_infeed_enqueue_ops_and_dequeue_fn(self): - """Generates infeed enqueue ops and dequeue_fn.""" - # While tf.while_loop is called, the body function, which invokes - # `enqueue_fn` passed in, is called to construct the graph. So, input_fn - # structure is recorded. - enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( - self._invoke_input_fn_and_record_structure()) - - self._validate_input_pipeline() - - def dequeue_fn(): - """dequeue_fn is used by TPU to retrieve the tensors.""" - # In the model-parallel case, both the host-side and device-side - # computations must agree on the core on which infeed takes place. We - # choose to perform infeed on logical core 0 of each replica. - values = self._infeed_queue.generate_dequeue_op(tpu_device=0) - # The unflatten process uses the structure information recorded above. - return self._inputs_structure_recorder.unflatten_features_and_labels( - values) - - return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) - - def _invoke_input_fn_and_record_structure(self): - """Deploys the input pipeline and record input structure.""" - enqueue_ops = [] - infeed_queues = [] - all_dataset_initializers = [] - num_hosts = self._ctx.num_hosts - tpu_host_placement_fn = self._ctx.tpu_host_placement_function - - run_infeed_loop_on_coordinator = True - - if self._sharded_per_core: - # Per-Core input pipeline deployment. - # Invoke input pipeline for each core and placed on the corresponding - # host. - for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.device(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - enqueue_ops_fn, captured_infeed_queue = ( - generate_per_core_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, self._inputs_structure_recorder, - host_device, host_id)) - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - run_infeed_loop_on_coordinator = False - enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - # Infeed_queue_getter must be called after enqueue_ops_fn is called. - infeed_queues.append(captured_infeed_queue.get()) - - elif self._ctx.is_input_broadcast_with_iterators(): - # Only calls input_fn in host 0. - host_device = tpu_host_placement_fn(host_id=0) - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, - self._inputs_structure_recorder, - num_hosts)) - if dataset_initializer: - all_dataset_initializers.append(dataset_initializer) - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - else: - for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.device(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - if self._ctx.is_input_per_host_with_iterators(): - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_per_host_v2_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, host_device, host_id)) - else: - enqueue_ops_fn, captured_infeed_queue, dataset_initializer = ( - generate_per_host_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, self._batch_axis, - host_device, host_id)) - - # NOTE(xiejw): We dispatch here based on the return type of the - # users `input_fn`. - # - # 1. If input_fn returns a Dataset instance, we initialize the - # iterator outside of tf.while_loop, and call the iterator.get_next - # inside tf.while_loop. This should be always safe. - # - # 2. If input_fn returns (features, labels), it is too late to wrap - # them inside tf.while_loop, as resource initialization cannot be - # handled in TF control flow properly. In this case, we will use - # python loop to enqueue the data into TPU system. This may be - # slow compared to the previous case. - if dataset_initializer: - all_dataset_initializers.append(dataset_initializer) - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append( - wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - # infeed_queue is used to generate dequeue ops. The only thing it uses for - # dequeue is dtypes and types. So, any one can be used. Here, grab the - # first one. - self._infeed_queue = infeed_queues[0] - return enqueue_ops, [ - util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers) - ], run_infeed_loop_on_coordinator - - def _validate_input_pipeline(self): - """Validates the input pipeline. - - Perform some sanity checks to log user friendly information. We should - error out to give users better error message. But, if - _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - user code, so, log a warning. - - Raises: - RuntimeError: If the validation failed. - """ - if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): - err_msg = ('Input pipeline contains one or more QueueRunners. ' - 'It could be slow and not scalable. Please consider ' - 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/guide/datasets for ' - 'instructions.') - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - raise RuntimeError(err_msg) - else: - logging.warn(err_msg) - - -def call_computation(computation, - experimental_exported_model_uses_all_cores=True): - """Call computation. - - computation uses a single-core for TPU inference. If - `experimental_exported_model_uses_all_cores` is `True`, this function will - round-robin - computation among all TPU cores visible to the host; otherwise, it will use - a single core. - - Args: - computation: A Python function that takes no inputs and builds computation - graph. If `computation` returns m outputs, this function will return a - list of m Tensors. - experimental_exported_model_uses_all_cores: Whether to round-robin among all - cores visible to the host, or to use a single core. - - Returns: - A list of output tensors. - """ - if experimental_exported_model_uses_all_cores: - # Using `TPUPartitionedCall` makes it possible to target a different - # TPU core with every `Session.run()` call. Note that the entire inference - # graph executes on a single core, and that invocations of this graph - # will round-robin among the cores attached to a host. - @function.Defun() - def tpu_subgraph(): - return computation() - - return tpu_functional.TPUPartitionedCall( - args=tpu_subgraph.captured_inputs, - device_ordinal=tpu_ordinal_selector_op.tpu_ordinal_selector(), - Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg], - f=tpu_subgraph) - else: - return computation() - - -class _ModelFnWrapper(object): - """A `model_fn` wrapper. - - This makes calling model_fn on CPU and TPU easier and more consistent and - performs necessary check and mutation required by TPU training and evaluation. - - In addition, this wrapper manages converting the `model_fn` to a single TPU - train and eval step. - """ - - def __init__(self, model_fn, config, params, ctx): - self._model_fn = model_fn - self._config = config - self._params = params - self._ctx = ctx - - def call_without_tpu(self, features, labels, is_export_mode): - return self._call_model_fn(features, labels, is_export_mode=is_export_mode) - - def _add_embedding_features(self, features): - if self._ctx.embedding_config: - tpu_embedding_ = self._ctx.embedding_config.tpu_embedding - embedding_activations = tpu_embedding_.get_activations() - features.update(embedding_activations) - - def convert_to_single_tpu_train_step(self, dequeue_fn): - """Converts user provided model_fn` as a single train step on TPU. - - The user provided `model_fn` takes input tuple - (features, labels) and produces the EstimatorSpec with train_op and loss for - train `mode`. This usually represents a single train computation on CPU. - - For TPU training, a train (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input should be taken from TPU infeed rather - than input pipeline (input_fn) directly. To fit TPU loop and replicate - pattern, the original train computation should be reformed, which is the - returned `train_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn - representing the train step for TPU. - """ - - host_call = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_training_hooks = _CapturedObject() - - def train_step(loss): - """Training step function for use inside a while loop.""" - del loss # unused; required in function signature. - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - self._add_embedding_features(features) - - estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels)) - loss, train_op = estimator_spec.loss, estimator_spec.train_op - - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - captured_scaffold_fn.capture(estimator_spec.scaffold_fn) - else: - captured_scaffold_fn.capture(None) - - captured_training_hooks.capture(estimator_spec.training_hooks) - - tracing_ops = [] - if tensor_tracer.TensorTracer.is_enabled(): - tt = tensor_tracer.TensorTracer() - loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), - loss, train_op, - self._ctx.num_replicas, - self._ctx.num_of_replicas_per_host, - self._ctx.num_hosts) - - if self._ctx.embedding_config is None: - apply_sparse_grads = [] - else: - tpu_embedding_ = self._ctx.embedding_config.tpu_embedding - apply_sparse_grads = [tpu_embedding_.generate_send_gradients_op()] - - # We must run train_op to update the variables prior to running the - # outfeed. - with ops.control_dependencies([train_op] + tracing_ops + - apply_sparse_grads): - host_call_outfeed_ops = [] - if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access - and estimator_spec.host_call is not None): - host_call.record({'host_call': estimator_spec.host_call}) - host_call_outfeed_ops = host_call.create_enqueue_op() - with ops.control_dependencies(host_call_outfeed_ops): - return array_ops.identity(loss) - - return (train_step, host_call, captured_scaffold_fn, - captured_training_hooks) - - def convert_to_single_tpu_eval_step(self, dequeue_fn): - """Converts user provided model_fn` as a single eval step on TPU. - - Similar to training, the user provided `model_fn` takes input tuple - (features, labels) and produces the TPUEstimatorSpec with eval_metrics for - eval `mode`. This usually represents a single evaluation computation on CPU. - - For TPU evaluation, a eval (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input and output are slightly different. Input, - features and labels, should be taken from TPU infeed rather than input - pipeline (input_fn) directly. Output is managed in two stages. First, the - model outputs as the result of evaluation computation, usually model logits, - should be transferred from TPU system to CPU. Then, all model outputs are - concatenated first on CPU and sent to the metric_fn for metrics computation. - To fit TPU evaluation pattern, the original eval computation should be - reformed, which is the returned `eval_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn - representing the eval step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_eval_hooks = _CapturedObject() - - def eval_step(total_loss): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - self._add_embedding_features(features) - - tpu_estimator_spec = self._call_model_fn(features, labels) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU evaluation must have type' - '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) - - loss = tpu_estimator_spec.loss - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) - - to_record = {} - if tpu_estimator_spec.eval_metrics: - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics - if tpu_estimator_spec.host_call is not None: - # We assume that evaluate won't update global step, so we don't wrap - # this host_call. - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return math_ops.add(total_loss, loss) - - return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - - def convert_to_single_tpu_predict_step(self, dequeue_fn): - """Converts user provided model_fn` as a single predict step on TPU. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of predict_fn, host_calls, and captured scaffold_fn. The - predict_fn representing the predict step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_predict_hooks = _CapturedObject() - - def predict_step(unused_scalar_stopping_signal): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - stopping_signals = inputs.signals() - - assert stopping_signals is not None, ( - 'Internal Error: `signals` is missing.') - - tpu_estimator_spec = self._call_model_fn( - features, labels, is_export_mode=False) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU prediction must have type' - '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) - - self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) - - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) - to_record = {} - identity_fn = lambda **kwargs: kwargs - to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] - to_record['signals'] = [identity_fn, stopping_signals] - if tpu_estimator_spec.host_call is not None: - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return _StopSignals.as_scalar_stopping_signal(stopping_signals) - - return (predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks) - - def _verify_tpu_spec_predictions(self, predictions): - """Validates TPUEstimatorSpec.predictions dict.""" - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') - - for (key, tensor) in predictions.items(): - if tensor.shape.dims[0].value is None: - raise ValueError( - 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' - 'dynamic shape (should be static). Tensor: {}'.format(key, tensor)) - return predictions - - def _validate_model_features_and_labels(self, features, labels, - is_export_mode): - """Validates that the features and labels for the model function are valid. - - A valid features/labels object is the one with: - - Type: A tensor or any nested structure of tensors supported by TF nest, - namely nested dictionary, tuple, namedtuple, or sequence of tensors. - - Static shape if is_export_mode is False. - - Args: - features: the features that would be input to the model function. - labels: the labels that would be input to the model function. - is_export_mode: boolean value specifying if in export mode. - - Raises: - TypeError: If features/labels are not of the correct type. - ValueError: If features/labels have dynamic shape. - """ - - def validate(obj, obj_name): - """Helper validate function.""" - if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): - return - if isinstance(obj, ops.Tensor): - if not obj.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static shape.' - ' Tensor: {}'.format(obj_name, obj)) - else: - for tensor in data_nest.flatten(obj): - if not tensor.get_shape().is_fully_defined(): - raise ValueError( - ('The {} to the model returned by input_fn must have static ' - 'shape. Tensor: {}').format(obj_name, tensor)) - - validate(features, 'features') - if labels is not None: - validate(labels, 'labels') - - def _call_model_fn(self, features, labels, is_export_mode=False): - """Calls the model_fn with required parameters.""" - self._validate_model_features_and_labels(features, labels, is_export_mode) - model_fn_args = function_utils.fn_args(self._model_fn) - kwargs = {} - - # Makes deep copy with `config` and params` in case user mutates them. - config = copy.deepcopy(self._config) - params = copy.deepcopy(self._params) - - if 'labels' in model_fn_args: - kwargs['labels'] = labels - elif labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') - if 'mode' in model_fn_args: - kwargs['mode'] = self._ctx.mode - if 'config' in model_fn_args: - kwargs['config'] = config - if 'params' in model_fn_args: - kwargs['params'] = params - - if 'params' not in model_fn_args: - raise ValueError('model_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params[\'batch_size\']'.format(self._model_fn)) - - if is_export_mode: - batch_size_for_model_fn = None - else: - batch_size_for_model_fn = self._ctx.batch_size_for_model_fn - - if batch_size_for_model_fn is not None: - _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) - - running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) - _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) - - if not running_on_cpu: - user_context = tpu_context.TPUContext( - internal_ctx=self._ctx, call_from_input_fn=False) - _add_item_to_params(params, _CTX_KEY, user_context) - - estimator_spec = self._model_fn(features=features, **kwargs) - if (running_on_cpu and - isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access - # The estimator_spec will be passed to `Estimator` directly, which expects - # type `EstimatorSpec`. - return estimator_spec.as_estimator_spec() - else: - return estimator_spec - - def _verify_estimator_spec(self, estimator_spec): - """Validates the estimator_spec.""" - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - return estimator_spec - - err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' - if estimator_spec.training_chief_hooks: - raise ValueError( - err_msg.format('training_chief_hooks') + 'If you want' + - ' to pass training hooks, please pass via training_hooks.') - - if estimator_spec.scaffold: - logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' - 'Please use TPUEstimatorSpec.') - return estimator_spec - - -class _OutfeedHostCall(object): - """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" - - def __init__(self, ctx): - self._ctx = ctx - self._names = [] - # All of these are dictionaries of lists keyed on the name. - self._host_fns = {} - self._tensor_keys = collections.defaultdict(list) - self._tensors = collections.defaultdict(list) - self._tensor_dtypes = collections.defaultdict(list) - self._tensor_shapes = collections.defaultdict(list) - - @staticmethod - def validate(host_calls): - """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" - - for name, host_call in host_calls.items(): - if not isinstance(host_call, (tuple, list)): - raise ValueError('{} should be tuple or list'.format(name)) - if len(host_call) != 2: - raise ValueError('{} should have two elements.'.format(name)) - if not callable(host_call[0]): - raise TypeError('{}[0] should be callable.'.format(name)) - if not isinstance(host_call[1], (tuple, list, dict)): - raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) - - if isinstance(host_call[1], (tuple, list)): - fullargspec = tf_inspect.getfullargspec(host_call[0]) - fn_args = function_utils.fn_args(host_call[0]) - # wrapped_hostcall_with_global_step uses varargs, so we allow that. - if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' - 'method args of the function, which takes {}.'.format( - name, len(host_call[1]), len(fn_args))) - - @staticmethod - def create_cpu_hostcall(host_calls): - """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" - - _OutfeedHostCall.validate(host_calls) - ret = {} - for name, host_call in host_calls.items(): - host_fn, tensors = host_call - if isinstance(tensors, (tuple, list)): - ret[name] = host_fn(*tensors) - else: - # Must be dict. - try: - ret[name] = host_fn(**tensors) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise - return ret - - def record(self, host_calls): - """Records the host_call structure.""" - - for name, host_call in host_calls.items(): - host_fn, tensor_list_or_dict = host_call - self._names.append(name) - self._host_fns[name] = host_fn - - if isinstance(tensor_list_or_dict, dict): - for (key, tensor) in six.iteritems(tensor_list_or_dict): - self._tensor_keys[name].append(key) - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - else: - # List or tuple. - self._tensor_keys[name] = None - for tensor in tensor_list_or_dict: - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - - def create_enqueue_op(self): - """Create the op to enqueue the recorded host_calls. - - Returns: - A list of enqueue ops, which is empty if there are no host calls. - """ - if not self._names: - return [] - - tensors = [] - # TODO(jhseu): Consider deduping tensors. - for name in self._names: - tensors.extend(self._tensors[name]) - - with ops.device(tpu.core(0)): - return [tpu_ops.outfeed_enqueue_tuple(tensors)] - - def create_tpu_hostcall(self): - """Sends the tensors through outfeed and runs the host_fn on CPU. - - The tensors are concatenated along dimension 0 to form a global tensor - across all shards. The concatenated function is passed to the host_fn and - executed on the first host. - - Returns: - A dictionary mapping name to the return type of the host_call by that - name. - - Raises: - RuntimeError: If outfeed tensor is scalar. - """ - if not self._names: - return {} - - ret = {} - # For each i, dequeue_ops[i] is a list containing the tensors from all - # shards. This list is concatenated later. - dequeue_ops = [] - tensor_dtypes = [] - tensor_shapes = [] - for name in self._names: - for _ in self._tensors[name]: - dequeue_ops.append([]) - for dtype in self._tensor_dtypes[name]: - tensor_dtypes.append(dtype) - for shape in self._tensor_shapes[name]: - tensor_shapes.append(shape) - - # Outfeed ops execute on each replica's first logical core. Note: we must - # constraint it such that we have at most one outfeed dequeue and enqueue - # per replica. - for i in xrange(self._ctx.num_replicas): - host_device, ordinal_id = self._ctx.device_for_replica(i) - with ops.device(host_device): - outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, - shapes=tensor_shapes, - device_ordinal=ordinal_id) - for j, item in enumerate(outfeed_tensors): - dequeue_ops[j].append(item) - - # Deconstruct dequeue ops. - flat_dequeue_ops = [] - for l in dequeue_ops: - flat_dequeue_ops.extend(l) - - dequeue_ops_by_name = {} - pos = 0 - for name in self._names: - dequeue_ops_by_name[name] = dequeue_ops[pos:pos + - len(self._tensors[name])] - pos += len(self._tensors[name]) - - def _call_host_fn(fn, *args, **kw): - context = CatchInvalidHostcallFunctions() - context.Enter() - result = fn(*args, **kw) - context.Exit() - context.ExitResult(result) - return result - - # It is assumed evaluation always happens on single host TPU system. So, - # place all ops on tpu host if possible. - # - # TODO(jhseu): Evaluate whether this is right for summaries. - with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): - for name in self._names: - dequeue_ops = dequeue_ops_by_name[name] - for i, item in enumerate(dequeue_ops): - if dequeue_ops[i][0].shape.ndims == 0: - raise RuntimeError( - 'All tensors outfed from TPU should preserve batch size ' - 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Make the specification of the outfeed combinaton - # function more explicit and well-documented. We may want to give the - # user the option of concatenating along any axis. - if (self._ctx.config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.BROADCAST): - # If the infeed is in BROADCAST mode (each core recieving the same - # input), then we assume that the cores also produce identical - # copies of the same output, and we simply take the output from - # the first core. This mode is used by Mesh-TensorFlow. - with ops.control_dependencies(dequeue_ops[i]): - dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0]) - else: - # Assume that the input has been batch-split and that axis 0 of the - # output tensors represents the batch size. Concatenate along - # the axis 0 to re-combine the batch. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) - - if self._tensor_keys[name] is not None: - # The user-provided eval_metrics[1] is a dict. - dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) - try: - ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise - else: - ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops) - - # force all dequeue operations to be run if not consumed by the host calls - ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops) - return ret - - -class _OutfeedHostCallHook(session_run_hook.SessionRunHook): - """Hook to run host calls when use_tpu=False.""" - - def __init__(self, tensors): - self._tensors = tensors - - def begin(self): - # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than - # create a separate hook to guarantee execution order, because summaries - # need to be initialized before the outfeed thread starts. - # TODO(jhseu): Make a wrapper hook instead? - self._init_ops = contrib_summary.summary_writer_initializer_op() - # Get all the writer resources from the initializer, so we know what to - # flush. - self._finalize_ops = [] - for op in self._init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def after_create_session(self, session, coord): - session.run(self._init_ops) - - def before_run(self, run_context): - return basic_session_run_hooks.SessionRunArgs(self._tensors) - - def end(self, session): - session.run(self._finalize_ops) - - -class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): - """Calculate and report global_step/sec and examples/sec during runtime.""" - - def __init__(self, - batch_size, - every_n_steps=100, - every_n_secs=None, - output_dir=None, - summary_writer=None): - self._batch_size = batch_size - super(ExamplesPerSecondHook, self).__init__( - every_n_steps=every_n_steps, - every_n_secs=every_n_secs, - output_dir=output_dir, - summary_writer=summary_writer) - - def _log_and_record(self, elapsed_steps, elapsed_time, global_step): - global_step_per_sec = elapsed_steps / elapsed_time - examples_per_sec = self._batch_size * global_step_per_sec - if self._summary_writer is not None: - global_step_summary = Summary(value=[ - Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) - ]) - example_summary = Summary(value=[ - Summary.Value(tag='examples/sec', simple_value=examples_per_sec) - ]) - self._summary_writer.add_summary(global_step_summary, global_step) - self._summary_writer.add_summary(example_summary, global_step) - logging.info('global_step/sec: %g', global_step_per_sec) - logging.info('examples/sec: %g', examples_per_sec) - - -class InstallSignalHandlerHook(session_run_hook.SessionRunHook): - """Change SIGINT (CTRL^C) handler to force quit the process. - - The default behavior often results in hanging processes. - The original handler is restored after training/evaluation. - """ - - def __init__(self): - self._signal_fn = signal.getsignal(signal.SIGINT) - - def before_run(self, run_context): - signal.signal(signal.SIGINT, signal.SIG_DFL) - - def end(self, session): - signal.signal(signal.SIGINT, self._signal_fn) - - -class TPUEstimator(estimator_lib.Estimator): - """Estimator with TPU support. - - TPUEstimator also supports training on CPU and GPU. You don't need to define - a separate `tf.estimator.Estimator`. - - TPUEstimator handles many of the details of running on TPU devices, such as - replicating inputs and models for each core, and returning to host - periodically to run hooks. - - TPUEstimator transforms a global batch size in params to a per-shard batch - size when calling the `input_fn` and `model_fn`. Users should specify - global batch size in constructor, and then get the batch size for each shard - in `input_fn` and `model_fn` by `params['batch_size']`. - - - For training, `model_fn` gets per-core batch size; `input_fn` may get - per-core or per-host batch size depending on `per_host_input_for_training` - in `TPUConfig` (See docstring for TPUConfig for details). - - - For evaluation and prediction, `model_fn` gets per-core batch size and - `input_fn` get per-host batch size. - - Evaluation - ========== - - `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` - for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return - `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case - the following discussion on TPU evaluation does not apply. - - `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where - `tensors` could be a list of any nested structure of `Tensor`s (See - `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns - a dict from metric string name to the result of calling a metric function, - namely a `(metric_tensor, update_op)` tuple. - - One can set `use_tpu` to `False` for testing. All training, evaluation, and - predict will be executed on CPU. `input_fn` and `model_fn` will receive - `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. - - Current limitations: - -------------------- - - 1. TPU evaluation only works on a single host (one TPU worker) except - BROADCAST mode. - - 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception - (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all - batches should have the same size. - - Example (MNIST): - ---------------- - - ``` - # The metric Fn which runs on CPU. - def metric_fn(labels, logits): - predictions = tf.argmax(logits, 1) - return { - 'accuracy': tf.metrics.precision( - labels=labels, predictions=predictions), - } - - # Your model Fn which runs on TPU (eval_metrics is list in this example) - def model_fn(features, labels, mode, config, params): - ... - logits = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, [labels, logits])) - - # or specify the eval_metrics tensors as dict. - def model_fn(features, labels, mode, config, params): - ... - final_layer_output = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, { - 'labels': labels, - 'logits': final_layer_output, - })) - ``` - - Prediction - ========== - - Prediction on TPU is an experimental feature to support large batch inference. - It is not designed for latency-critical system. In addition, due to some - usability issues, for prediction with small dataset, CPU `.predict`, i.e., - creating a new `TPUEstimator` instance with `use_tpu=False`, might be more - convenient. - - Note: In contrast to TPU training/evaluation, the `input_fn` for prediction - *should* raise an end-of-input exception (`OutOfRangeError` or - `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be - precise, the ops created by `input_fn` produce one batch of the data. - The `predict()` API processes one batch at a time. When reaching the end of - the data source, an end-of-input exception should be raised by one of these - operations. The user usually does not need to do this manually. As long as the - dataset is not repeated forever, the `tf.data` API will raise an end-of-input - exception automatically after the last batch has been produced. - - Note: Estimator.predict returns a Python generator. Please consume all the - data from the generator so that TPUEstimator can shutdown the TPU system - properly for user. - - Current limitations: - -------------------- - 1. TPU prediction only works on a single host (one TPU worker). - - 2. `input_fn` must return a `Dataset` instance rather than `features`. In - fact, .train() and .evaluate() also support Dataset as return value. - - Example (MNIST): - ---------------- - ``` - height = 32 - width = 32 - total_examples = 100 - - def predict_input_fn(params): - batch_size = params['batch_size'] - - images = tf.random_uniform( - [total_examples, height, width, 3], minval=-1, maxval=1) - - dataset = tf.data.Dataset.from_tensor_slices(images) - dataset = dataset.map(lambda images: {'image': images}) - - dataset = dataset.batch(batch_size) - return dataset - - def model_fn(features, labels, params, mode): - # Generate predictions, called 'output', from features['image'] - - if mode == tf.estimator.ModeKeys.PREDICT: - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, - predictions={ - 'predictions': output, - 'is_padding': features['is_padding'] - }) - - tpu_est = TPUEstimator( - model_fn=model_fn, - ..., - predict_batch_size=16) - - # Fully consume the generator so that TPUEstimator can shutdown the TPU - # system. - for item in tpu_est.predict(input_fn=input_fn): - # Filter out item if the `is_padding` is 1. - # Process the 'predictions' - ``` - - Exporting - ========= - - `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, - and another with `tag_constants.SERVING` and `tag_constants.TPU`. - At serving time, these tags are used to select metagraph to load. - - Before running the graph on TPU, TPU system needs to be initialized. If - TensorFlow Serving model-server is used, this is done automatically. If - not, please call `session.run(tpu.initialize_system())`. - - `tpu.outside_compilation` can be used to wrap TPU incompatible ops in - `model_fn`. - - Example: - ---------------- - - ``` - def model_fn(features, labels, mode, config, params): - ... - logits = ... - export_outputs = { - 'logits': export_output_lib.PredictOutput( - {'logits': logits}) - } - - def host_call(logits): - class_ids = math_ops.argmax(logits) - classes = string_ops.as_string(class_ids) - export_outputs['classes'] = - export_output_lib.ClassificationOutput(classes=classes) - - tpu.outside_compilation(host_call, logits) - - ... - ``` - - """ - - def __init__(self, - model_fn=None, - model_dir=None, - config=None, - params=None, - use_tpu=True, - train_batch_size=None, - eval_batch_size=None, - predict_batch_size=None, - batch_axis=None, - eval_on_tpu=True, - export_to_tpu=True, - export_to_cpu=True, - warm_start_from=None, - experimental_exported_model_uses_all_cores=False, - experimental_export_device_assignment=False, - experimental_embedding_config_spec=None): - """Constructs an `TPUEstimator` instance. - - Args: - model_fn: Model function as required by `Estimator` which returns - EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', - and `prediction_hooks` must not capure any TPU Tensor inside the - model_fn. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. If `None`, the model_dir in - `config` will be used if set. If both are set, they must be same. If - both are `None`, a temporary directory will be used. - config: An `tpu_config.RunConfig` configuration object. Cannot be `None`. - params: An optional `dict` of hyper parameters that will be passed into - `input_fn` and `model_fn`. Keys are names of parameters, values are - basic python types. There are reserved keys for `TPUEstimator`, - including 'batch_size'. - use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training and evaluation respect this bit, but eval_on_tpu can - override execution of eval. See below. - Predict still happens on CPU. - train_batch_size: An int representing the global training batch size. - TPUEstimator transforms this global batch size to a per-shard batch - size, as params['batch_size'], when calling `input_fn` and `model_fn`. - Cannot be `None` if `use_tpu` is `True`. Must be divisible by total - number of replicas. - eval_batch_size: An int representing evaluation batch size. Must be - divisible by total number of replicas. - predict_batch_size: An int representing the prediction batch size. Must be - divisible by total number of replicas. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. For example, if your input_fn produced (images, labels) - where the images tensor is in `HWCN` format, your shard dimensions would - be [3, 0], where 3 corresponds to the `N` dimension of your images - Tensor, and 0 corresponds to the dimension along which to split the - labels to match up with the corresponding images. If None is supplied, - and per_host_input_for_training is True, batches will be sharded based - on the major dimension. If tpu_config.per_host_input_for_training is - False or `PER_HOST_V2`, batch_axis is ignored. - eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the - model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. - export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU. Note that unsupported export modes such as EVAL will be - ignored. For those modes, only a CPU model will be exported. - Currently, export_to_tpu only supports PREDICT. - export_to_cpu: If True, `export_savedmodel()` exports a metagraph for - serving on CPU. - warm_start_from: Optional string filepath to a checkpoint or SavedModel to - warm-start from, or a `tf.estimator.WarmStartSettings` object to fully - configure warm-starting. If the string filepath is provided instead of - a `WarmStartSettings`, then all variables are warm-started, and it is - assumed that vocabularies and Tensor names are unchanged. - experimental_exported_model_uses_all_cores: Whether to round-robin among - all cores visible to the host which is serving the saved model, or to - use a single core. This is a temporary flag to enable using all TPU - cores for inference with TPUPartitionedCall(). Once outside compilation - is supported in TPUPartitionedCall(), this flag will be enabled by - default. - experimental_export_device_assignment: Whether to include the device - assignment in the exported model. Doing so is useful in case of model - parallel inference but will tie the exported model to the TPU topology - used to export the model. - experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance - to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE - DO NOT USE. - - Raises: - ValueError: `params` has reserved keys already. - """ - if config is None or not isinstance(config, tpu_config.RunConfig): - raise ValueError( - '`config` must be provided with type `tpu_config.RunConfig`') - - if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): - raise ValueError('{} are reserved keys but existed in params {}.'.format( - _RESERVED_PARAMS_KEYS, params)) - - if use_tpu: - # Perform some very basic validations. More validations will be found in - # _InternalTPUContext. - if train_batch_size is None: - raise ValueError('`train_batch_size` cannot be `None`') - util_lib.check_positive_integer(train_batch_size, 'train_batch_size') - - if (config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.num_cores_per_replica): - raise ValueError( - 'Model parallelism only supports per host input for training. ' - 'Please adjust TPURunconfig.per_host_input_for_training.') - - if eval_batch_size is not None: - util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') - - if predict_batch_size is not None: - util_lib.check_positive_integer(predict_batch_size, - 'predict_batch_size') - - # Verifies the model_fn signature according to Estimator framework. - estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access - # We cannot store config and params in this constructor as parent - # constructor might change them, such as assigning a temp dir for - # config.model_dir. - model_function = self._augment_model_fn(model_fn, batch_axis) - - # Overwrite log_step_count_steps to disable TensorLoggingHook and - # StepCounterHook from being created in Estimator. TPUEstimator already - # added equivalent hooks in _augment_model_fn above. - self._log_every_n_steps = config.log_step_count_steps - config = config.replace(log_step_count_steps=None) - - # Passing non-None params as wrapped model_fn has it. - params = params or {} - super(TPUEstimator, self).__init__( - model_fn=model_function, - model_dir=model_dir, - config=config, - params=params, - warm_start_from=warm_start_from) - self._iterations_per_training_loop = ( - self._config.tpu_config.iterations_per_loop) - - # All properties passed to _InternalTPUContext are immutable. - # pylint: disable=protected-access - self._ctx = tpu_context._get_tpu_context( - self._config, train_batch_size, eval_batch_size, predict_batch_size, - use_tpu, eval_on_tpu, experimental_embedding_config_spec) - - self._export_to_cpu = export_to_cpu - self._export_to_tpu = export_to_tpu - self._experimental_exported_model_uses_all_cores = ( - experimental_exported_model_uses_all_cores) - self._experimental_export_device_assignment = ( - experimental_export_device_assignment) - if (experimental_exported_model_uses_all_cores and - experimental_export_device_assignment): - raise ValueError('experimental_exported_model_uses_all_cores and ' - 'experimental_export_device_assignment is not supported ' - 'at the same time.') - - self._is_input_fn_invoked = None - self._rendezvous = {} - - def _add_meta_graph_for_mode(self, - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables=True, - mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None, - check_variables=True): - if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - logging.warning('TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; Mode {} will be ignored ' - 'for TPU.'.format(mode)) - - if not self._export_to_cpu and not self._export_to_tpu: - raise ValueError('One of export_to_cpu and export_to_tpu must be true.') - - if self._export_to_cpu: - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) - - if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT: - input_receiver_fn_map = { - _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode] - } - export_tags = [tag_constants.SERVING, tag_constants.TPU] - mode = _REWRITE_FOR_INFERENCE_MODE - - # See b/110052256 for why `check_variables` is `False`. - if not self._export_to_cpu: - check_variables = save_variables = True - else: - check_variables = save_variables = False - (super(TPUEstimator, self)._add_meta_graph_for_mode( - builder, - input_receiver_fn_map, - checkpoint_path, - save_variables=save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) - - def _call_model_fn(self, features, labels, mode, config): - if mode == _REWRITE_FOR_INFERENCE_MODE: - return self._call_model_fn_for_inference(features, labels, mode, config) - else: - return super(TPUEstimator, self)._call_model_fn(features, labels, mode, - config) - - def _call_model_fn_for_inference(self, features, labels, mode, config): - """Wraps `_call_model_fn` for `export_savedmodel`.""" - if mode != _REWRITE_FOR_INFERENCE_MODE: - raise ValueError('mode must be {}; ' - 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) - - computation, capture = self._build_computation_for_inference( - features, labels, mode, config) - tensors = call_computation( - computation, - experimental_exported_model_uses_all_cores=self - ._experimental_exported_model_uses_all_cores) - estimator_spec, export_outputs_dict, predictions_dict, none_indices = ( - capture.get()) - predictions_list = tensors[:len(predictions_dict)] - export_outputs_list_without_none = tensors[len(predictions_dict):] - - # Reinsert `None`s which we've taken out in - # `_build_computation_for_inference()`. - export_outputs_list = [] - while none_indices or export_outputs_list_without_none: - if none_indices and none_indices[0] == len(export_outputs_list): - export_outputs_list.append(None) - none_indices.pop(0) - else: - export_outputs_list.append(export_outputs_list_without_none.pop(0)) - - # Reconstruct `export_outputs` with updated tensors. - new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict, - export_outputs_list) - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_export_outputs_dict)) - # Reconstruct `predictions` with updated tensors. - new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list) - if (len(new_predictions) == 1 and - _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions): - new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR] - - return estimator_spec._replace( - export_outputs=new_export_outputs, predictions=new_predictions) - - def _build_computation_for_inference(self, features, labels, mode, config): - capture = _CapturedObject() - - def computation(): - """Computation to be passed to `TPUPartitionedCall()`.""" - tpu_computation, tpu_capture = self._build_tpu_computation_for_inference( - features, labels, mode, config) - - if self._experimental_export_device_assignment: - # Export the device assignment as part of the model. This is useful for - # model parallel usecases where the model relies on the mapping between - # logical and physical devices. - with self._ctx.with_mode(mode) as ctx: - device_assignment = ctx.device_assignment - else: - device_assignment = None - tensors_on_cpu = tpu.rewrite_for_inference( - tpu_computation, device_assignment=device_assignment) - (estimator_spec, export_outputs_dict, export_outputs_list, - predictions_dict) = ( - tpu_capture.get()) - predictions_list = tensors_on_cpu[:len(predictions_dict)] - export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):] - - # Reconstruct tensors used in export_outputs, with TPU tensors replaced - # with their CPU counterpart returned from `rewrite_for_inference()`. - # `function.Defun()` does not like `None`s in return values, so we leave - # `None`s out but record their positions for later reconstruction. - export_outputs_list_without_none = [] - none_indices = [] - for i, t in enumerate(export_outputs_list): - if t is None: - none_indices.append(i) - else: - export_outputs_list_without_none.append( - export_outputs_tpu_on_cpu_list.pop(0)) - - capture.capture((estimator_spec, export_outputs_dict, predictions_dict, - none_indices)) - return predictions_list + export_outputs_list_without_none - - return computation, capture - - def _build_tpu_computation_for_inference(self, features, labels, mode, - config): - capture = _CapturedObject() - - def computation(): - """Compute tpu tensors used in export_outputs. - - Passed to rewrite_for_inference so that model_fn will be called under - the rewriting contexts. Only tpu tensors are returned, but export_outputs - and scaffold are captured. - - Returns: - A list of Tensors used in export_outputs and not marked for - outside_compilation. - """ - # We should only call model fn once and it should be inside `computation` - # so that building the graph will happen under `rewrite_for_inference`. - mode = model_fn_lib.ModeKeys.PREDICT - estimator_spec = self._call_model_fn(features, labels, mode, config) - - # We pick the TPU tensors out from `export_output` and later return them - # from `computation` for rewriting. - export_outputs_dict = collections.OrderedDict( - (k, _export_output_to_tensors(v)) - for k, v in six.iteritems(estimator_spec.export_outputs)) - export_outputs_list = nest.flatten(export_outputs_dict) - export_outputs_tpu_list = [ - t for t in export_outputs_list if t is not None - ] - - if isinstance(estimator_spec.predictions, dict): - predictions_dict = collections.OrderedDict( - (k, v) for k, v in six.iteritems(estimator_spec.predictions)) - else: - predictions_dict = { - _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions - } - predictions_list = nest.flatten(predictions_dict) - - # We cannot return everything we want through the return values, so - # capture the rest here for later use. - capture.capture((estimator_spec, export_outputs_dict, export_outputs_list, - predictions_dict)) - return predictions_list + export_outputs_tpu_list - - return computation, capture - - def _create_global_step(self, graph): - """Creates a global step suitable for TPUs. - - Args: - graph: The graph in which to create the global step. - - Returns: - A global step `Tensor`. - - Raises: - ValueError: if the global step tensor is already defined. - """ - return _create_global_step(graph) - - def _convert_train_steps_to_hooks(self, steps, max_steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) - - # On TPU. - if steps is None and max_steps is None: - raise ValueError( - 'For TPU training, one of `steps` or `max_steps` must be set. ' - 'Cannot be both `None`.') - - # Estimator.train has explicit positiveness check. - if steps is not None: - util_lib.check_positive_integer(steps, 'Train steps') - if max_steps is not None: - util_lib.check_positive_integer(max_steps, 'Train max_steps') - - return [ - _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) - ] - - def _convert_eval_steps_to_hooks(self, steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) - - if steps is None: - raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - - util_lib.check_positive_integer(steps, 'Eval steps') - - return [ - evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access - num_evals=steps), - _SetEvalIterationsHook(steps) - ] - - def _call_input_fn(self, input_fn, mode): - """Calls the input function. - - Args: - input_fn: The input function. - mode: ModeKeys - - Returns: - In TPU mode, returns an input_fn to be called later in model_fn. - Otherwise, calls the input_fn and returns either fatures or - (features, labels). - - Raises: - ValueError: if input_fn takes invalid arguments or does not have `params`. - """ - input_fn_args = function_utils.fn_args(input_fn) - config = self.config # a deep copy. - kwargs = {} - if 'params' in input_fn_args: - kwargs['params'] = self.params # a deep copy. - else: - raise ValueError('input_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params["batch_size"]'.format(input_fn)) - if 'config' in input_fn_args: - kwargs['config'] = config - - if 'mode' in input_fn_args: - kwargs['mode'] = mode - - # Records the fact input_fn has been invoked. - self._is_input_fn_invoked = True - - with self._ctx.with_mode(mode) as ctx: - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - batch_size_for_input_fn = ctx.batch_size_for_input_fn - if batch_size_for_input_fn is not None: - _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY, - batch_size_for_input_fn) - - # For export_savedmodel, input_fn is never passed to Estimator. So, - # `is_export_mode` must be False. - if ctx.is_running_on_cpu(is_export_mode=False): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - # For TPU computation, input_fn should be invoked in a tf.while_loop for - # performance. While constructing the tf.while_loop, the structure of - # inputs returned by the `input_fn` needs to be recorded. The structure - # includes whether features or labels is dict or single Tensor, dict keys, - # tensor shapes, and dtypes. The recorded structure is used to create the - # infeed dequeue ops, which must be wrapped and passed as a Fn, called - # inside the TPU computation, as the TPU computation is wrapped inside a - # tf.while_loop also. So, we either pass input_fn to model_fn or pass - # dequeue_fn to model_fn. Here, `input_fn` is passed directly as - # `features` in `model_fn` signature. - def _input_fn(ctx): - _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) - return input_fn(**kwargs) - - return _input_fn - - def _validate_features_in_predict_input(self, result): - """Skip the validation. - - For TPUEstimator, we do not need to check the result type. `_InputPipeline` - has stronger check. Parent class's check generates confusing warning msg. - - Args: - result: `features` returned by input_fn. - """ - pass - - def train(self, - input_fn, - hooks=None, - steps=None, - max_steps=None, - saving_listeners=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous - try: - return super(TPUEstimator, self).train( - input_fn=input_fn, - hooks=hooks, - steps=steps, - max_steps=max_steps, - saving_listeners=saving_listeners) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('training_loop', sys.exc_info()) - finally: - rendezvous.record_done('training_loop') - rendezvous.raise_errors() - - def evaluate(self, - input_fn, - steps=None, - hooks=None, - checkpoint_path=None, - name=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous - try: - return super(TPUEstimator, self).evaluate( - input_fn, - steps=steps, - hooks=hooks, - checkpoint_path=checkpoint_path, - name=name) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('evaluation_loop', sys.exc_info()) - finally: - rendezvous.record_done('evaluation_loop') - rendezvous.raise_errors() - - def predict(self, - input_fn, - predict_keys=None, - hooks=None, - checkpoint_path=None, - yield_single_examples=True): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous - try: - for result in super(TPUEstimator, self).predict( - input_fn=input_fn, - predict_keys=predict_keys, - hooks=hooks, - checkpoint_path=checkpoint_path, - yield_single_examples=yield_single_examples): - yield result - except Exception: # pylint: disable=broad-except - rendezvous.record_error('prediction_loop', sys.exc_info()) - finally: - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - def _augment_model_fn(self, model_fn, batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - with self._ctx.with_mode(mode) as ctx: - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) - - # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, - # but not in `export_savedmodel()`. - if self._is_input_fn_invoked: - is_export_mode = False - else: - is_export_mode = True - - # Clear the bit. - self._is_input_fn_invoked = None - - # examples_hook is added to training_hooks for both CPU and TPU - # execution. - if self._log_every_n_steps is not None: - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - # pylint:disable=g-long-ternary - output_dir=(self.model_dir - if not config or config.save_summary_steps - else None), - # pylint:enable=g-long-ternary - every_n_steps=self._log_every_n_steps) - - if ctx.is_running_on_cpu(is_export_mode=is_export_mode): - logging.info('Running %s on CPU', mode) - estimator_spec = model_fn_wrapper.call_without_tpu( - features, labels, is_export_mode=is_export_mode) - if self._log_every_n_steps is not None: - estimator_spec = estimator_spec._replace( - training_hooks=estimator_spec.training_hooks + (examples_hook,)) - return estimator_spec - - assert labels is None, '`labels` passed to `model_fn` must be `None`.' - # TPUEstimator._call_input_fn passes `input_fn` as features to here. - assert callable(features), '`input_fn` is not callable.' - input_fn = features - - tpu_init_ops = [] - if ctx.embedding_config: - tpu_init_ops.extend(ctx.embedding_config.tpu_embedding.init_ops) - embedding_variables_and_ops = ( - ctx.embedding_config.tpu_embedding.create_variables_and_ops()) - tpu_init_ops.extend(embedding_variables_and_ops.load_ops) - - input_holders = _InputPipeline(input_fn, batch_axis, ctx) - enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( - input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) - - graph = ops.get_default_graph() - for enqueue_op in enqueue_ops: - if isinstance(enqueue_op, list): - graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op) - else: - graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) - - if mode == model_fn_lib.ModeKeys.TRAIN: - compile_op, loss, host_call, scaffold, training_hooks = ( - _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) - host_ops = host_call.create_tpu_hostcall() - if host_ops is None: - host_ops = [] - - shutdown_hooks = [] - shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', - 'shutdown_worker') - if shutdown_mode: - if shutdown_mode == 'shutdown_worker': - finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=60 * 1000), - ] - elif shutdown_mode == 'shutdown_computation': - finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=60 * 1000), - ] - else: - raise ValueError( - 'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode) - - shutdown_hooks.append( - session_support.GracefulShutdownHook( - checkpoint_prefix=self.model_dir + '/model.ckpt', - on_shutdown_hooks=finalizer_hooks)) - - with ops.control_dependencies([loss]): - global_step = array_ops.identity(training.get_global_step()) - hooks = input_hooks + shutdown_hooks - hooks.extend([ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - host_ops, - tpu_compile_op=compile_op, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - master=self._config.master, - session_config=self._session_config, - tpu_init_ops=tpu_init_ops), - InstallSignalHandlerHook() - ]) - if self._log_every_n_steps is not None: - logging_hook_frequency = ( # Divide and round up - (self._log_every_n_steps + - self._config.tpu_config.iterations_per_loop - 1) // - self._config.tpu_config.iterations_per_loop) - hooks.append( - training.LoggingTensorHook({ - 'loss': array_ops.identity(loss), - 'step': global_step, - }, - every_n_iter=logging_hook_frequency)) - examples_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - hooks.append(examples_hook) - - if training_hooks: - hooks.extend(training_hooks) - - chief_hooks = [] - if (self._config.save_checkpoints_secs or - self._config.save_checkpoints_steps): - checkpoint_hook = training.CheckpointSaverHook( - self.model_dir, - save_secs=self._config.save_checkpoints_secs, - save_steps=self._config.save_checkpoints_steps, - scaffold=scaffold) - checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - chief_hooks.append(checkpoint_hook) - - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops(ctx) - - if ctx.embedding_config: - update_ops.extend(embedding_variables_and_ops.retrieve_ops) - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - train_op = control_flow_ops.group(*update_ops) - graph.add_to_collection(_TPU_TRAIN_OP, train_op) - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_chief_hooks=chief_hooks, - training_hooks=hooks, - train_op=train_op, - scaffold=scaffold) - - if mode == model_fn_lib.ModeKeys.EVAL: - compile_op, total_loss, host_calls, scaffold, eval_hooks = ( - _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) - iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) - - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), - # reads all variables back from TPU and updates the eval step - # counter properly - internal_ops_to_run = _sync_variables_ops(ctx) - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) - - host_call_ret = host_calls.create_tpu_hostcall() - eval_metric_ops = {} - eval_update_ops = [] - - eval_metrics = host_call_ret.get('eval_metrics', {}) - if eval_metrics: - # Creates a dummy metric update_op for all metrics. Estimator - # expects all metrics in `eval_metric_ops` have update_op and calls - # them one by one. The real metric update_ops are invoked in a - # separated thread. So, here give Estimator the dummy op for all - # metrics. - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - for k, v in eval_metrics.items(): - eval_metric_ops[k] = (v[0], dummy_update_op) - eval_update_ops.append(v[1]) - else: - # If no eval metrics are passed, create an identity node for the - # loss and add `internal_ops_to_run` to its dependencies. So - # `internal_ops_to_run` can be executed. - with ops.control_dependencies(internal_ops_to_run): - mean_loss = array_ops.identity(mean_loss) - - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - hooks = [ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - eval_update_ops + host_ops, - tpu_compile_op=compile_op, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - master=self._config.evaluation_master, - session_config=self._session_config, - tpu_init_ops=tpu_init_ops) - ] + input_hooks - - if eval_hooks: - hooks.extend(eval_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops, - scaffold=scaffold) - - # Predict - assert mode == model_fn_lib.ModeKeys.PREDICT - - (compile_op, dummy_predict_op, host_calls, - scaffold, prediction_hooks) = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) - with ops.control_dependencies([dummy_predict_op]): - internal_ops_to_run = _sync_variables_ops(ctx) - with ops.control_dependencies(internal_ops_to_run): - dummy_predict_op = control_flow_ops.no_op() - - # In train and evaluation, the main TPU program is passed to monitored - # training session to run. Infeed enqueue and outfeed dequeue are - # executed in side threads. This is not the configuration for - # prediction mode. - # - # For prediction, the Estimator executes the EstimatorSpec.predictions - # directly and yield the element (via generator) to call site. So, the - # outfeed based prediction must be passed to MonitoredSession directly. - # Other parts of the TPU execution are organized as follows. - # - # 1. All outfeed based Tensors must be grouped with predictions Tensors - # to form a single invocation. This avoid the issue we might trigger - # multiple outfeeds incorrectly. To achieve this, `host_call` is - # placed in control_dependencies of `stopping_signals`, and - # `stopping_signals` is passed into _StoppingPredictHook, which sets - # the `stopping_signals` as SessionRunArgs. MonitoredSession merges - # all SessionRunArgs with the fetch in session.run together. - # - # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) - # are grouped together. They will be launched once and only once in - # side threads and they quit naturally according to the SAME stopping - # condition. - enqueue_ops.append(dummy_predict_op) - - host_call_ret = host_calls.create_tpu_hostcall() - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - - predictions = host_call_ret['predictions'] - _verify_cross_hosts_transfer_size( - predictions, - message=( - 'The estimated size for TPUEstimatorSpec.predictions is too ' - 'large.')) - signals = host_call_ret['signals'] - - with ops.control_dependencies(host_ops): - host_ops = [] # Empty, we do do not need it anymore. - scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( - signals) - predictions = _PaddingSignals.slice_tensor_or_dict( - predictions, signals) - - hooks = [ - _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode], - tpu_compile_op=compile_op, - master=self._config.master, - session_config=self._session_config), - ] + input_hooks - - if prediction_hooks: - hooks.extend(prediction_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - prediction_hooks=hooks, - predictions=predictions, - scaffold=scaffold) - - return _model_fn - - -def _export_output_to_tensors(export_output): - """Get a list of `Tensors` used in `export_output`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - - Returns: - a list of tensors used in export_output. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - return [export_output.scores, export_output.classes] - elif isinstance(export_output, export_output_lib.RegressionOutput): - return [export_output.value] - elif isinstance(export_output, export_output_lib.PredictOutput): - return list(export_output.outputs.values()) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _clone_export_output_with_tensors(export_output, tensors): - """Clones `export_output` but with new `tensors`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - tensors: a list of `Tensors` used to construct a new `export_output`. - - Returns: - A dict similar to `export_output` but with `tensors`. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - if len(tensors) != 2: - raise ValueError('tensors must be of length 2; ' - 'got {}.'.format(len(tensors))) - return export_output_lib.ClassificationOutput(*tensors) - elif isinstance(export_output, export_output_lib.RegressionOutput): - if len(tensors) != 1: - raise ValueError('tensors must be of length 1; ' - 'got {}'.format(len(tensors))) - return export_output_lib.RegressionOutput(*tensors) - elif isinstance(export_output, export_output_lib.PredictOutput): - return export_output_lib.PredictOutput( - dict(zip(export_output.outputs.keys(), tensors))) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) - - def multi_tpu_eval_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step, - [_ZERO_LOSS]) - - (compile_op, loss,) = tpu.split_compile_and_shard( - multi_tpu_eval_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - loss = loss[0] - scaffold = _get_scaffold(captured_scaffold_fn) - return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get() - - -def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_train_step, host_call, captured_scaffold_fn, - captured_training_hooks) = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) - - def multi_tpu_train_steps_on_single_shard(): - return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step, - [_INITIAL_LOSS]) - - (compile_op, loss,) = tpu.split_compile_and_shard( - multi_tpu_train_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - loss = loss[0] - scaffold = _get_scaffold(captured_scaffold_fn) - return compile_op, loss, host_call, scaffold, captured_training_hooks.get() - - -def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - (single_tpu_predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks - ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) - - def multi_tpu_predict_steps_on_single_shard(): - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - inputs = [_StopSignals.NON_STOPPING_SIGNAL] - outputs = training_loop.while_loop( - cond, single_tpu_predict_step, inputs=inputs, name=b'loop') - return outputs - - (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard( - multi_tpu_predict_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - dummy_predict_op = dummy_predict_op[0] - scaffold = _get_scaffold(captured_scaffold_fn) - return (compile_op, dummy_predict_op, host_calls, scaffold, - captured_predict_hooks.get()) - - -def _wrap_computation_in_while_loop(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def computation(i): - with ops.control_dependencies(op_fn()): - return i + 1 - - iterations_per_loop_var = _create_or_get_iterations_per_loop() - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.device(device): - iterations = array_ops.identity(iterations_per_loop_var) - return control_flow_ops.while_loop( - lambda i: i < iterations, - computation, [constant_op.constant(0)], - parallel_iterations=1) - - -def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - def computation(unused_scalar_stopping_signal): - return_value = op_fn() - execute_ops = return_value['ops'] - signals = return_value['signals'] - with ops.control_dependencies(execute_ops): - return _StopSignals.as_scalar_stopping_signal(signals) - - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.device(device): - return control_flow_ops.while_loop( - cond, - computation, [_StopSignals.NON_STOPPING_SIGNAL], - parallel_iterations=1) - - -def _validate_tpu_training_graph(): - """Validate graph before running distributed training. - - Raises: - ValueError: If the graph seems invalid for running on device - """ - operations = ops.get_default_graph().get_operations() - - # Check if there is atleast one CrossReplicaSum operation in the graph - # This should be introduced by using the CrossShardOptimizer wrapper - cross_replica_sum_ops = [ - o for o in operations if o.type == _CROSS_REPLICA_SUM_OP - ] - if not cross_replica_sum_ops: - raise ValueError( - 'CrossShardOptimizer must be used for model training on TPUs.') - - -class _CapturedObject(object): - """A placeholder to capture an object. - - This is useful when we need to capture a Python object in the Tensorflow - control flow body function and use it outside the control flow. - """ - - def __init__(self): - self._object = None - self._captured = False - - def capture(self, o): - if self._captured: - raise RuntimeError( - 'InternalError: Object can capture only once. Please file bug.') - - self._captured = True - self._object = o - - def get(self): - if not self._captured: - raise RuntimeError( - 'InternalError: Object is not captured properly before `get`. ' - 'Please file bug.') - return self._object - - -def _get_scaffold(captured_scaffold_fn): - """Retrieves the Scaffold from `captured_scaffold_fn`.""" - with _CapturingContext(message='Inside scaffold_fn'): - scaffold_fn = captured_scaffold_fn.get() - if scaffold_fn: - scaffold = scaffold_fn() - if scaffold is None: - raise ValueError( - 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') - else: - scaffold = None - - if scaffold: - wrapped_finalize = scaffold.finalize - - def _finalize(): - with _CapturingContext('Inside Scaffold.finalize'): - wrapped_finalize() - - scaffold.finalize = _finalize - return scaffold - - -class _CapturingContext(control_flow_ops.ControlFlowContext): - """Tracks references to Tensors defined in TPU replication.""" - - def __init__(self, message): - control_flow_ops.ControlFlowContext.__init__(self) - self._message = message - - def to_control_flow_context_def(self, context_def, export_scope=None): - # pylint: disable=useless-super-delegation - # NOTE(slebedev): the method is required by `ControlFlowContext`. - super(_CapturingContext, self).to_control_flow_context_def( - context_def, export_scope) - - def AddOp(self, op): # pylint: disable=invalid-name - for c in op.inputs: - if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access - raise ValueError('{}: Op {} depends on TPU computation {}, ' - 'which is not allowed.'.format(self._message, op, c)) - - def __enter__(self): - # pylint: disable=protected-access - self._g = ops.get_default_graph() - self._old = self._g._get_control_flow_context() - self._g._set_control_flow_context(self) - # pylint: enable=protected-access - - def __exit__(self, _, __, ___): # pylint: disable=invalid-name - self._g._set_control_flow_context(self._old) # pylint: disable=protected-access - - -class _Inputs(object): - """A data structure representing the input_fn returned values. - - This also supports the returned value from input_fn as `Dataset`. - """ - - def __init__(self, features=None, labels=None, dataset=None, signals=None): - if dataset is not None and (features is not None or labels is not None or - signals is not None): - raise RuntimeError('Internal Error: Either (features and labels) or ' - 'dataset should be provided, not both. Please file ' - 'bug') - - self._features = features - self._labels = labels - self._signals = signals - - self._dataset = dataset - self._iterator = None - - @staticmethod - def from_input_fn(return_values): - """Returns an `_Inputs` instance according to `input_fn` return value.""" - if isinstance(return_values, dataset_ops.DatasetV2): - dataset = return_values - return _Inputs(dataset=dataset) - - features, labels = _Inputs._parse_inputs(return_values) - return _Inputs(features, labels) - - @staticmethod - def _parse_inputs(return_values): - if isinstance(return_values, tuple): - features, labels = return_values - else: - features, labels = return_values, None - return features, labels - - @property - def is_dataset(self): - """Returns True if the return value from input_fn is Dataset.""" - return self._dataset is not None - - def dataset_initializer(self): - """Returns the dataset's initializer. - - The initializer must be run before calling `features_and_labels`. - """ - self._iterator = dataset_ops.make_initializable_iterator(self._dataset) - return self._iterator.initializer - - def features_and_labels(self): - """Gets `features` and `labels`.""" - if self.is_dataset: - if self._iterator is None: - raise RuntimeError('Internal error: Must run dataset_initializer ' - 'before calling features_and_labels(). Please file ' - 'a bug!') - return _Inputs._parse_inputs(self._iterator.get_next()) - - return (self._features, self._labels) - - def signals(self): - return self._signals - - @property - def dataset(self): - return self._dataset - - -class _InputsWithStoppingSignals(_Inputs): - """Inputs with `_StopSignals` inserted into the dataset.""" - - def __init__(self, - dataset, - batch_size, - add_padding=False, - num_invocations_per_step=1): - - assert dataset is not None - user_provided_dataset = dataset.map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=False, batch_size=batch_size, add_padding=add_padding)) - if num_invocations_per_step == 1: - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - else: - # We append (2 * num_invocations_per_step - 1) batches for exhausting the - # user_provided_dataset and stop properly. - # For example, if num_invocations_per_step is 2, we append 3 additional - # padding batches: b1, b2, b3. - # If user_provided_dataset contains two batches: a1, a2 - # Step 1: [a1, a2] - # Step 2: [b1, b2] -> STOP - # If user_provided_dataset contains three batches: a1, a2, a3. - # The training loops: - # Step 1: [a1, a2] - # Step 2: [a3, b1] - # Step 3: [b2, b3] -> STOP. - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - final_batch_dataset = final_batch_dataset.repeat( - 2 * num_invocations_per_step - 1) - - def _set_mask(data_dict): - signals = data_dict['signals'] - signals['padding_mask'] = array_ops.ones_like(signals['padding_mask']) - data_dict['signals'] = signals - return data_dict - - # Mask out the extra batch. - final_batch_dataset = final_batch_dataset.map(_set_mask) - - dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) - - super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) - self._current_inputs = None - - def features_and_labels(self): - if self._current_inputs is not None: - raise RuntimeError( - 'Internal Error: The previous inputs have not been properly ' - 'consumed. First call features_and_labels, then call signals.') - - inputs_with_signals = self._iterator.get_next() - features = inputs_with_signals['features'] - labels = inputs_with_signals.get('labels') - - self._current_inputs = inputs_with_signals - return features, labels - - def signals(self): - """Returns the `Signals` from `_Inputs`.""" - if self._current_inputs is None: - raise RuntimeError( - 'Internal Error: The current inputs have not been properly ' - 'generated. First call features_and_labels, then call signals.') - signals = self._current_inputs['signals'] - self._current_inputs = None - return signals - - @staticmethod - def insert_stopping_signal(stop, batch_size, add_padding=False): - """Inserts stopping_signal into dataset via _map_fn. - - Here we change the data structure in the dataset, such that the return value - is a dictionary now and `features`, `labels`, and `signals` are three - distinguished keys in that dict. This provides a better structure, which - eases the process to decompose the inputs (see `features_and_labels`). - - Args: - stop: bool, state of current stopping signals. - batch_size: int, batch size. - add_padding: bool, whether to pad the tensor to full batch size. - - Returns: - A map_fn passed to dataset.map API. - """ - - def _map_fn(*args): - """The map fn to insert signals.""" - if len(args) == 1: - # Unpack the single Tensor/dict argument as features. This is required - # for the input_fn returns no labels. - args = args[0] - features, labels = _Inputs._parse_inputs(args) - new_input_dict = {} - - if add_padding: - padding_mask, features, labels = ( - _PaddingSignals.pad_features_and_labels(features, labels, - batch_size)) - - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - - else: - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - padding_mask = None - - new_input_dict['signals'] = _StopSignals( - stop=stop, batch_size=batch_size, - padding_mask=padding_mask).as_dict() - - return new_input_dict - - return _map_fn - - -class _StopSignals(object): - """Signals class holding all logic to handle TPU stopping condition.""" - - NON_STOPPING_SIGNAL = False - STOPPING_SIGNAL = True - - def __init__(self, stop, batch_size, padding_mask=None): - self._stop = stop - self._batch_size = batch_size - self._padding_mask = padding_mask - - def as_dict(self): - """Returns the signals as Python dict.""" - shape = [self._batch_size, 1] - dtype = dtypes.bool - - if self._stop: - stopping = array_ops.ones(shape=shape, dtype=dtype) - else: - stopping = array_ops.zeros(shape=shape, dtype=dtype) - - signals = {'stopping': stopping} - if self._padding_mask is not None: - signals['padding_mask'] = self._padding_mask - return signals - - @staticmethod - def as_scalar_stopping_signal(signals): - return array_ops.identity(signals['stopping'][0][0]) - - @staticmethod - def should_stop(scalar_stopping_signal): - """Detects whether scalar_stopping_signal indicates stopping.""" - if isinstance(scalar_stopping_signal, ops.Tensor): - # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF - # way to express the bool check whether scalar_stopping_signal is True. - return math_ops.logical_and(scalar_stopping_signal, - _StopSignals.STOPPING_SIGNAL) - else: - # For non Tensor case, it is used in SessionRunHook. So, we cannot modify - # the graph anymore. Here, we use pure Python. - return bool(scalar_stopping_signal) - - -class _PaddingSignals(object): - """Signals class holding all logic to handle padding.""" - - @staticmethod - def pad_features_and_labels(features, labels, batch_size): - """Pads out the batch dimension of features and labels.""" - real_batch_size = array_ops.shape( - _PaddingSignals._find_any_tensor(features))[0] - - batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) - - check_greater = check_ops.assert_greater_equal( - batch_size_tensor, - real_batch_size, - data=(batch_size_tensor, real_batch_size), - message='The real batch size should not be greater than batch_size.') - - with ops.control_dependencies([check_greater]): - missing_count = batch_size_tensor - real_batch_size - - def pad_single_tensor(tensor): - """Pads out the batch dimension of a tensor to the complete batch_size.""" - rank = len(tensor.shape) - assert rank > 0 - padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) - padded_shape = (batch_size,) + tuple(tensor.shape[1:]) - padded_tensor = array_ops.pad(tensor, padding) - padded_tensor.set_shape(padded_shape) - return padded_tensor - - def nest_pad(tensor_or_dict): - return nest.map_structure(pad_single_tensor, tensor_or_dict) - - features = nest_pad(features) - if labels is not None: - labels = nest_pad(labels) - - padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count, - batch_size) - - return padding_mask, features, labels - - @staticmethod - def slice_tensor_or_dict(tensor_or_dict, signals): - """Slice the real Tensors according to padding mask in signals.""" - - padding_mask = signals['padding_mask'] - batch_size = array_ops.shape(padding_mask)[0] - - def verify_batch_size(tensor): - check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) - with ops.control_dependencies([check_batch_size]): - return array_ops.identity(tensor) - - def slice_single_tensor(tensor): - rank = len(tensor.shape) - assert rank > 0 - real_batch_size = batch_size - math_ops.reduce_sum(padding_mask) - return verify_batch_size(tensor)[0:real_batch_size] - - # As we split the Tensors to all TPU cores and concat them back, it is - # important to ensure the real data is placed before padded ones, i.e., - # order is preserved. By that, the sliced padding mask should have all 0's. - # If this assertion failed, # the slice logic here would not hold. - sliced_padding_mask = slice_single_tensor(padding_mask) - assert_padding_mask = math_ops.equal( - math_ops.reduce_sum(sliced_padding_mask), 0) - - with ops.control_dependencies([assert_padding_mask]): - should_stop = _StopSignals.should_stop( - _StopSignals.as_scalar_stopping_signal(signals)) - - is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) - - def slice_fn(tensor): - # If the current batch is full batch or part of stopping signals, we do - # not need to slice to save performance. - return control_flow_ops.cond( - math_ops.logical_or(should_stop, is_full_batch), - (lambda: verify_batch_size(tensor)), - (lambda: slice_single_tensor(tensor))) - - return nest.map_structure(slice_fn, tensor_or_dict) - - @staticmethod - def _find_any_tensor(batch_features): - tensors = [ - x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor) - ] - if not tensors: - raise ValueError('Cannot find any Tensor in features dict.') - return tensors[0] - - @staticmethod - def _padding_mask(real_batch_size, missing_count, batch_size): - padding_mask = array_ops.concat([ - array_ops.zeros((real_batch_size,), dtype=dtypes.int32), - array_ops.ones((missing_count,), dtype=dtypes.int32) - ], - axis=0) - padding_mask.set_shape((batch_size,)) - return padding_mask - - -def _verify_cross_hosts_transfer_size(tensor_dict, message): - total_size = 0 - tensor_structure = {} - for key, tensor in tensor_dict.items(): - shape = tensor.shape - size = np.product(shape) * tensor.dtype.size - tensor_structure[key] = shape - total_size += size - if total_size >= _ONE_GIGABYTE: - raise ValueError( - '{} The transfer size is larger than the protobuf limit. Please ' - 'consider to use Tensors with smaller shapes or reduce batch ' - 'size. Given:\n' - '{}'.format( - message, '\n'.join([ - ' -- Key: {}, Shape: {}'.format(k, v) - for k, v in tensor_structure.items() - ]))) - - -def _add_item_to_params(params, key, value): - """Adds a new item into `params`.""" - if isinstance(params, hparam.HParams): - # For HParams, we need to use special API. - if key in params: - params.set_hparam(key, value) - else: - params.add_hparam(key, value) - else: - # Now params is Python dict. - params[key] = value - - -def export_estimator_savedmodel(estimator, - export_dir_base, - serving_input_receiver_fn, - assets_extra=None, - as_text=False, - checkpoint_path=None, - strip_default_attrs=False): - """Export `Estimator` trained model for TPU inference. - - Args: - estimator: `Estimator` with which model has been trained. - export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - serving_input_receiver_fn: A function that takes no argument and returns a - `ServingInputReceiver` or `TensorServingInputReceiver`. - assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. - checkpoint_path: The checkpoint path to export. If `None` (the default), - the most recent checkpoint found within the model directory is chosen. - strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. - - Returns: - The string path to the exported directory. - """ - # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use - # `estimator.config`. - config = tpu_config.RunConfig(model_dir=estimator.model_dir) - est = TPUEstimator( - estimator._model_fn, # pylint: disable=protected-access - config=config, - params=estimator.params, - use_tpu=True, - train_batch_size=2048, # Does not matter. - eval_batch_size=2048, # Does not matter. - ) - return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, - assets_extra, as_text, checkpoint_path, - strip_default_attrs) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_estimator import * +# used by tests +from tensorflow.python.tpu.tpu_estimator import _clone_export_output_with_tensors +from tensorflow.python.tpu.tpu_estimator import _create_global_step +from tensorflow.python.tpu.tpu_estimator import _export_output_to_tensors +from tensorflow.python.tpu.tpu_estimator import _get_scaffold +from tensorflow.python.tpu.tpu_estimator import _Inputs +from tensorflow.python.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR +from tensorflow.python.tpu.tpu_estimator import _TPU_ENQUEUE_OPS +from tensorflow.python.tpu.tpu_estimator import _TPU_ESTIMATOR +from tensorflow.python.tpu.tpu_estimator import _TPU_TRAIN_OP +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index d5957b7e8ec40b40c7af8822378cee6134ef0d0f..af2542ea85290170ce6a38223188c4f9b871f032 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -1,898 +1,25 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""Helper library for handling infeed between hosts and TPUs. -""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.contrib.tpu.python.tpu import tpu_sharding - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.util import nest - - -class InfeedQueue(object): - """A helper object to build a device infeed queue. - - The InfeedQueue builds the host-side and device-side Ops to enqueue and - dequeue elements, respectively, and ensures that their types and - shapes match. - """ - - def __init__(self, - number_of_tuple_elements=None, - tuple_types=None, - tuple_shapes=None, - shard_dimensions=None, - name=None): - """Creates a new InfeedQueue with the given configuration. - - The configuration need not be fully specified at creation since it - can be modified subsequently by methods that set the values - explicitly or infer them from the shapes of inputs. - - Args: - number_of_tuple_elements: the number of Tensors fed atomically through the - queue, must be present unless it can be inferred from other arguments. - tuple_types: if not None, a list of types of the elements of the queue. - tuple_shapes: if not None, a list of shapes of the elements of the queue. - shard_dimensions: if not None, a list of dimensions on which the - elements of the queue should be sharded during automatic - parallelization. - name: the name of the queue. - - Raises: - ValueError: if number_of_tuple_elements <= 0; or - number_of_tuple_arguments, tuple_types, tuple_shapes, and - shard_dimensions are all None; or the length of tuple_types, - tuple_shapes, or shard_dimensions is not equal to - number_of_tuple_elements; or any element of shard_dimensions - can't be converted to a Dimension. - TypeError: if any element of tuple_types or tuple_shapes can't - be converted to a dtype or TensorShape, respectively. - """ - self._frozen = False - self._generated_enqueue_ops = False - self._generated_dequeue_op = False - self._name = "InfeedQueue" if name is None else name - if number_of_tuple_elements is None: - if tuple_types is not None: - number_of_tuple_elements = len(tuple_types) - elif tuple_shapes is not None: - number_of_tuple_elements = len(tuple_shapes) - elif shard_dimensions is not None: - number_of_tuple_elements = len(shard_dimensions) - else: - raise ValueError( - "number of tuple elements cannot be inferred from InfeedQueue " - "constructor") - if number_of_tuple_elements <= 0: - raise ValueError("number_of_tuple_elements %d must be > 0" % - number_of_tuple_elements) - # Make an empty sharding policy for each tuple element. - self._sharding_policies = [ - tpu_sharding.ShardingPolicy() - for _ in xrange(number_of_tuple_elements) - ] - if tuple_types is not None: - self.set_tuple_types(tuple_types) - else: - self._tuple_types = None - if tuple_shapes is not None: - self.set_tuple_shapes(tuple_shapes) - else: - self._tuple_shapes = None - if shard_dimensions is not None: - self.set_shard_dimensions(shard_dimensions) - self._validate() - - def _validate(self): - """Checks that the configuration is self-consistent. - - Raises: - ValueError: if the shapes and sharding policies don't match. - """ - if self.tuple_shapes is not None: - for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): - # Raise an error if the policy is incompatible with the shape. - _ = policy.get_sharded_shape(shape) - - @property - def number_of_tuple_elements(self): - """Returns the number of InfeedQueue tuple elements.""" - return len(self._sharding_policies) - - @property - def tuple_types(self): - """Returns the types of the InfeedQueue tuple elements.""" - return self._tuple_types - - def set_tuple_types(self, tuple_types): - """Sets the type of each element of the queue. - - tuple_types must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a dtype. - - Args: - tuple_types: the types of each queue element. - - Raises: - ValueError: if tuple_types is not of length - self.number_of_tuple_elements. - TypeError: if an element of tuple_types cannot be converted to a - dtype. - """ - if len(tuple_types) != self.number_of_tuple_elements: - raise ValueError("tuple_types is %s, but must be a list of length %d" % - (str(tuple_types), self.number_of_tuple_elements)) - if self._frozen: - for (frozen, updated) in zip(self._tuple_types, tuple_types): - if frozen != updated: - raise ValueError( - "Trying to update InfeedQueue with frozen configuration with an " - "incompatible type. Frozen types are %s, updated types are %s" % ( - str(self._tuple_types), str(tuple_types))) - else: - try: - self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] - except (TypeError) as e: - raise TypeError( - "tuple_types is %s, but must be a list of elements each " - "convertible to dtype: got error %s" % (str(tuple_types), str(e))) - - @property - def tuple_shapes(self): - """Returns the shapes of the InfeedQueue tuple elements.""" - return self._tuple_shapes - - def set_tuple_shapes(self, tuple_shapes): - """Sets the shape of each element of the queue. - - tuple_shapes must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a TensorShape. - - Args: - tuple_shapes: the shapes of each queue element. - - Raises: - ValueError: if tuple_shapes is not of length - self.number_of_tuple_elements. - TypeError: if an element of tuple_shapes cannot be converted to - a TensorShape. - """ - if len(tuple_shapes) != self.number_of_tuple_elements: - raise ValueError("tuple_shapes is %s, but must be a list of length %d" % - (str(tuple_shapes), self.number_of_tuple_elements)) - try: - tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] - except (ValueError, TypeError) as e: - raise TypeError( - "tuple_shapes is %s, but must be a list of elements each " - "convertible to TensorShape: got error %s" % (str(tuple_shapes), - str(e))) - if self._frozen: - for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): - if frozen != updated: - raise ValueError( - "Trying to update InfeedQueue with frozen configuration with an " - "incompatible shape. Frozen shapes are %s, updated shapes are %s" - % (str(self._tuple_shapes), str(tuple_shapes))) - else: - self._tuple_shapes = tuple_shapes - self._validate() - - @property - def sharding_policies(self): - """Returns the sharding policies of the InfeedQueue tuple elements.""" - return self._sharding_policies - - @property - def shard_dimensions(self): - """Gets the shard dimension of each tuple element. - - Returns: - A list of length number_of_tuple_elements, where each list entry - is the shard dimension of that tuple element or None if the - shard dimension has not been set. - """ - # The number of shards is always the same for all the policies. - return [policy.shard_dimension for policy in self._sharding_policies] - - def set_shard_dimensions(self, shard_dimensions): - """Sets the shard_dimension of each element of the queue. - - shard_dimensions must be a list of length - self.number_of_tuple_elements, and each element must be - convertible to a Dimension compatible with self.tuple_shapes. - - Args: - shard_dimensions: the dimensions of each queue element. - - Raises: - ValueError: if shard_dimensions is not of length - self.number_of_tuple_elements; or an element of - shard_dimensions cannot be converted to a Dimension; or an - element of shard_dimensions is a Dimension that is out of - range for the corresponding tuple element shape. - """ - if len(shard_dimensions) != self.number_of_tuple_elements: - raise ValueError("shard_dimensions is %s, but must be a list of length %d" - % (str(shard_dimensions), - self.number_of_tuple_elements)) - for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): - policy.set_shard_dimension(dimension) - self._validate() - - @property - def number_of_shards(self): - """Gets the number of shards to use for the InfeedQueue. - - Returns: - Number of shards or None if the number of shards has not been set. - """ - # The number of shards is always the same for all the policies. - return self._sharding_policies[0].number_of_shards - - def set_number_of_shards(self, number_of_shards): - """Sets the number of shards to use for the InfeedQueue. - - Args: - number_of_shards: number of ways to shard the InfeedQueue. - - Raises: - ValueError: if number_of_shards is not > 0; or the policies have - been frozen and number_of_shards was already set to something - else. - """ - for policy in self._sharding_policies: - policy.set_number_of_shards(number_of_shards) - self._validate() - - def set_configuration_from_input_tensors(self, input_tensors): - """Sets the shapes and types of the queue tuple elements. - - input_tensors is a list of Tensors whose types and shapes are used - to set the queue configuration. - - Args: - input_tensors: list of Tensors of the same types and shapes as - the desired queue Tuple. - - Raises: - ValueError: if input_tensors is not a list of length - self.number_of_tuple_elements - """ - if len(input_tensors) != self.number_of_tuple_elements: - raise ValueError("input_tensors is %s, but should be a list of %d Tensors" - % (str(input_tensors), self.number_of_tuple_elements)) - self.set_tuple_shapes([t.shape for t in input_tensors]) - self.set_tuple_types([t.dtype for t in input_tensors]) - - def set_configuration_from_sharded_input_tensors(self, input_tensors): - """Sets the shapes and types of the queue tuple elements. - - input_tensors is a list of lists of Tensors whose types and shapes are used - to set the queue configuration. The length of the outer list is the number - of shards required, and each inner list is the tuple of Tensors to use to - determine the types and shapes of the corresponding shard. This method - depends on the shard dimension, and calling it freezes the shard policy. - - Args: - input_tensors: list of lists of Tensors. The outer list length corresponds - to the desired number of shards, and each inner list is the size - and shape of the desired configuration of the corresponding shard. - - Raises: - ValueError: if any inner list is not a list of length - self.number_of_tuple_elements; or the inner lists do not combine to - form a consistent unsharded shape. - TypeError: if the types of the Tensors in the inner lists do not match. - """ - if not self._frozen: - # Unset the tuple shapes in case the configuration becomes - # transiently inconsistent. - self._tuple_shapes = None - number_of_shards = len(input_tensors) - self.set_number_of_shards(number_of_shards) - for t in input_tensors: - if len(t) != self.number_of_tuple_elements: - raise ValueError( - "input_tensors is %s but must be a list of lists, where each inner" - " list has length number_of_tuple_elements=%d" % ( - str(input_tensors), self.number_of_tuple_elements)) - # Transpose the inputs to make a list of shard shapes for each tuple - # element. - sharded_shapes = [[t[i].shape for t in input_tensors] - for i in xrange(self.number_of_tuple_elements)] - # For each tuple, get the unsharded shape using that tuple's policy. - unsharded_shapes = [ - policy.get_unsharded_shape(s) - for (policy, s) in zip(self._sharding_policies, sharded_shapes) - ] - self.set_tuple_shapes(unsharded_shapes) - for i in xrange(1, self.number_of_shards): - for (t1, t2) in zip(input_tensors[0], input_tensors[i]): - if t1.dtype != t2.dtype: - raise TypeError( - "types of the tuple elements of input_tensors %s are not " - "consistent" % str(input_tensors)) - self.set_tuple_types([t.dtype for t in input_tensors[0]]) - - def freeze(self): - """Freezes the InfeedQueue so it can no longer be modified. - - The configuration is implicitly frozen before any host-side or - device-side Ops are generated. The configuration cannot be frozen - until the types and shapes of the tuple elements have been set. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set. - """ - self._frozen = True - if self._tuple_types is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple types.") - if self._tuple_shapes is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple shapes.") - for shape in self._tuple_shapes: - if shape.dims is None: - raise ValueError( - "Can't freeze an InfeedQueue without setting all tuple shapes.") - for policy in self._sharding_policies: - policy.freeze() - self._validate() - - def generate_dequeue_op(self, tpu_device=0): - """Generates the device-side Op to dequeue a tuple from the queue. - - Implicitly freezes the queue configuration if it is not already - frozen, which will raise errors if the shapes and types have not - been fully specified. - - Args: - tpu_device: The TPU device ordinal where the infeed instruction should be - placed. If None, no explicit placement will be performed, and it is up - to the user to call this API from within a proper TPU device scope. - The XLA code will fail if the TPU dequeue instruction is not bound to - any device. - - Returns: - A list of Outputs corresponding to a shard of infeed dequeued - into XLA, suitable for use within a replicated block. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set; or if a dequeue op has already been generated. - """ - self.freeze() - if self._generated_dequeue_op: - raise ValueError("Can't generate two dequeue Ops from the same queue") - self._generated_dequeue_op = True - full_name = "%s/dequeue" % self._name - sharded_shapes = [ - policy.get_sharded_shape(shape) - for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) - ] - if tpu_device is not None: - with ops.device(tpu.core(tpu_device)): - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - else: - return tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - - def _generate_enqueue_op(self, - inputs, - name_prefix, - index, - device=None, - tpu_ordinal=-1): - """Generate a host-side Op to enqueue a tuple to the queue. - - If device is None the inputs are all required to have the same - device specification, and the enqueue Op is colocated with - inputs[0]. Otherwise the enqueue Op is placed on 'device'. - - Args: - inputs: a list of Tensors with the types and shapes of the tuple elements. - name_prefix: the base name for the Op. - index: the shard index, used to uniquify the Op name. - device: device to place the Op on, or None if it should be - colocated with the inputs. - tpu_ordinal: ordinal of the TPU device on the host to use for - infeed if device is a CPU device. Should be set to -1 if device - is a TPU device. - - Returns: - An Op corresponding to a shard of infeed enqueued at the host, - suitable for use within a replicated block. - - Raises: - ValueError: if device is None and inputs do not all have the - same device specification. - """ - full_name = "%s/%d" % (name_prefix, index) - shapes = [t.shape for t in inputs] - if device is None: - devices = [t.device for t in inputs] - for i in xrange(1, self.number_of_tuple_elements): - if devices[0] != devices[i]: - raise ValueError( - "input devices for shard %d are %s, but should all be the same" % - (index, str(devices))) - with ops.colocate_with(inputs[0]): - return tpu_ops.infeed_enqueue_tuple( - inputs=inputs, - shapes=shapes, - name=full_name, - device_ordinal=tpu_ordinal) - else: - with ops.device(device): - return tpu_ops.infeed_enqueue_tuple( - inputs=inputs, - shapes=shapes, - name=full_name, - device_ordinal=tpu_ordinal) - - def generate_enqueue_ops(self, - sharded_inputs, - tpu_ordinal_function=None, - placement_function=None): - """Generates the host-side Ops to enqueue the shards of a tuple. - - sharded_inputs is a list, one for each shard, of lists of - Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed - shard 0 if the queue. Returns the host-side Ops that must be run to - enqueue the sharded tuple. The Op for shard i is colocated with the inputs - for shard i. - - Implicitly freezes the queue configuration if it is not already - frozen. If the configuration has already been frozen, and is not - compatible with the types and shapes of sharded_inputs, an error - will be raised. - - Args: - sharded_inputs: a list of lists of Tensors. The length of the outer list - determines the number of shards. Each inner list indicates the types - and shapes of the tuples in the corresponding shard. - tpu_ordinal_function: if not None, a function that takes the - shard index as input and returns the ordinal of the TPU device - the shard's infeed should be placed on. tpu_ordinal_function must be - set if the inputs are placed on CPU devices. - placement_function: if not None, a function that takes the shard index as - input and returns the host device where the enqueue op should be placed - on. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the shapes of the elements of sharded_inputs - don't form a consistent unsharded tuple; or if the elements of a tuple - have different device constraints. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the types of the elements of sharded_inputs - don't form a consistent unsharded tuple. - """ - self.set_configuration_from_sharded_input_tensors(sharded_inputs) - self.freeze() - if self._generated_enqueue_ops: - raise ValueError("Can't generate two enqueue Ops from the same queue") - self._generated_enqueue_ops = True - if tpu_ordinal_function is None: - tpu_ordinal_function = lambda index: -1 - name_prefix = "%s/enqueue" % self._name - return [ - self._generate_enqueue_op( - shard, - name_prefix, - index, - tpu_ordinal=tpu_ordinal_function(index), - device=placement_function(index) if placement_function else None) - for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) - ] - - # TODO(misard) Generalize this to the case of systems that don't - # have 8 devices per host, and figure out what to do with - # model-parallelism. - def _default_placement_function(self, index): - return "/task:%d/device:CPU:0" % (index / 8) - - def _default_ordinal_function(self, index): - return index % 8 - - # TODO(b/36470756) remove this from tutorials once we have a better story - # for automatic placement of input pipelines. - def split_inputs_and_generate_enqueue_ops(self, - inputs, - device_assignment=None, - placement_function=None, - tpu_ordinal_function=None): - """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. - - Generates the host-side Ops to enqueue a tuple. - - This method performs poorly because it takes an entire input on a single - host, splits it, and distributes it to all of the cores. It is present only - to simplify tutorial examples. - - inputs is a list of Tensors to use to feed the queue. Each input is split - into self.number_of_shards shards. Returns an Op for each shard to enqueue - the shard. The Op for shard i is placed on device placement_function(i). - - Implicitly freezes the queue configuration if it is not already - frozen. If the configuration has already been frozen, and is not - compatible with the types and shapes of inputs, an error - will be raised. - - Args: - inputs: a list of Tensors which indicates the types and shapes of the - queue tuple. - device_assignment: if not `None`, a TPU `DeviceAssignment`. If - device_assignment is not `None`, but `placement_function` and - `ordinal_function` are None, then `device_assignment` will be used to - place infeeds on the first k TPU shards, where k is the number of shards - in the queue. If all three are `None`, then default placement and - ordinal functions are used. - placement_function: if not None, a function that takes the shard - index as input and returns a device string indicating which - device the shard's infeed should be placed on. If placement_function - and tpu_ordinal_function are None, inputs are sharded round-robin - across the devices in the system. - tpu_ordinal_function: if not None, a function that takes the - shard index as input and returns the ordinal of the TPU device - the shard's infeed should be placed on. If placement_function - and tpu_ordinal_function are None, inputs are sharded round-robin - across the devices in the system. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of inputs are not compatible with the frozen - configuration. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of inputs are not compatible with the frozen - configuration. - """ - if device_assignment is None: - if placement_function is None: - placement_function = self._default_placement_function - if tpu_ordinal_function is None: - tpu_ordinal_function = self._default_ordinal_function - else: - - def _placement_function_from_map(index): - return device_assignment.host_device(replica=index) - - def _ordinal_function_from_map(index): - return device_assignment.tpu_ordinal(replica=index) - - if placement_function is None: - placement_function = _placement_function_from_map - if tpu_ordinal_function is None: - tpu_ordinal_function = _ordinal_function_from_map - self.set_configuration_from_input_tensors(inputs) - self.freeze() - if self._generated_enqueue_ops: - raise ValueError("Can't generate two enqueue Ops from the same queue") - self._generated_enqueue_ops = True - split_name_prefix = "%s/split" % self._name - if self.number_of_shards == 1: - transposed_sharded_inputs = [[inp] for inp in inputs] - else: - - def split_fn(inp, num_shards, axis, name): - with ops.colocate_with(inp): - return array_ops.split(inp, num_shards, axis=axis, name=name) - - transposed_sharded_inputs = [ - split_fn( - inp, - self.number_of_shards, - axis=policy.shard_dimension, - name="%s/%d" % (split_name_prefix, index)) - for (inp, policy, index) in zip(inputs, self._sharding_policies, - xrange(self.number_of_tuple_elements)) - ] - sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] - for i in xrange(self.number_of_shards)] - name_prefix = "%s/enqueue" % self._name - return [ - self._generate_enqueue_op( - shard, - name_prefix, - index, - device=placement_function(index), - tpu_ordinal=tpu_ordinal_function(index)) - for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) - ] - - -class _PartitionedInfeedQueue(InfeedQueue): - """A helper object to build a device infeed queue with input partition. - - Args: - number_of_tuple_elements: the number of Tensors fed atomically through the - queue, must be present unless it can be inferred from other arguments. - device_assignment: A TPU `DeviceAssignment` which is used to place all the - partitions to different TPU infeed queues. - host_id: The id of the host machine. - input_partition_dims: A nested list/tuple of integers. Each inner - list/tuple describes how to partition the corresponding input tensor. - tuple_types: If not None, a list of types of the elements of the queue. - tuple_shapes: If not None, a list of shapes of the elements of the queue. - name: The name of the queue. - """ - - def __init__(self, - number_of_tuple_elements, - device_assignment, - host_id, - input_partition_dims=None, - tuple_types=None, - tuple_shapes=None, - name=None): - super(_PartitionedInfeedQueue, self).__init__( - number_of_tuple_elements=number_of_tuple_elements, - tuple_types=tuple_types, - tuple_shapes=None, - shard_dimensions=None, - name="PartitionedInfeedQueue" if name is None else name) - self._input_partition_dims = input_partition_dims - self._host_id = host_id - self._device_assignment = device_assignment - - def generate_dequeue_op(self, tpu_device=0): - """Generate TPU dequeue ops. - - Args: - tpu_device: The TPU device ordinal where the infeed instruction should be - placed. - - Returns: - A list of Outputs corresponding to a partition of infeed dequeued - into XLA, suitable for use within a replicated block. - - Raises: - ValueError: if the types or shapes of the tuple elements have not been - set; or if a dequeue op has already been generated. - """ - self.freeze() - if self._generated_dequeue_op: - raise ValueError("Can't generate two dequeue Ops from the same queue") - self._generated_dequeue_op = True - full_name = "%s/dequeue" % self._name - sharded_shapes = [ - policy.get_sharded_shape(shape) - for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) - ] - with ops.device(tpu.core(tpu_device)): - values = tpu_ops.infeed_dequeue_tuple( - dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) - return self._tag_sharding_attribute_for_dequeued_tensors( - values, self._input_partition_dims) - - def generate_enqueue_ops(self, per_host_sharded_inputs): - """Generates the host-side Ops to enqueue the partitioned inputs. - - per_host_sharded_inputs is a list, one for each replica, of lists of - Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed - replica i. - sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. - - For example, if sharded_inputs[i][j] is a 2-D Tensor: - [[A, B, C, D], - [E ,F, G, H]] - self._input_partition_dims[j] is [2, 4]. - - sharded_inputs[i][j] will be partitioned and flattened into: - [A, B, C, D, E, F, G, H] and fed into the logical core ids: - [0, 1, 2, 3, 4, 5, 6, 7] respectively. - - Args: - per_host_sharded_inputs: a list of lists of Tensors. The length of the - outer list determines the number of shards. Each inner list indicates - the types and shapes of the tuples in the corresponding shard. - - Returns: - A list of host-side Ops, one for each shard, that when executed together - will enqueue a full-size element of infeed. - - Raises: - ValueError: if the queue configuration has previously been frozen and the - shapes of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the shapes of the elements of sharded_inputs - don't form a consistent unsharded tuple; or if the elements of a tuple - have different device constraints; or if the partition dims are invalid. - TypeError: if the queue configuration has previously been frozen and the - types of the elements of sharded_inputs are not compatible with the - frozen configuration; or if the types of the elements of sharded_inputs - don't form a consistent unsharded tuple. - """ - self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) - number_of_replicas_per_host = len(per_host_sharded_inputs) - number_of_tuple_elements = len(per_host_sharded_inputs[0]) - - assert len(self._input_partition_dims) == number_of_tuple_elements - per_host_enqueue_ops = [] - - for replica_index in range(number_of_replicas_per_host): - flattened_inputs = per_host_sharded_inputs[replica_index] - inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, - self._input_partition_dims) - inputs_parted_iters = [ - iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in - zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) - ] - - for logical_core in xrange(self._device_assignment.num_cores_per_replica): - # Places different partitions to different logic cores. - replica_id = self._device_assignment.lookup_replicas( - self._host_id, logical_core)[replica_index] - ordinal = self._device_assignment.tpu_ordinal( - replica=replica_id, logical_core=logical_core) - infeed_inputs = [] - for it in inputs_parted_iters: - input_for_device = next(it, None) - if input_for_device is not None: - infeed_inputs.append(input_for_device) - - if infeed_inputs: - per_host_enqueue_ops.append( - tpu_ops.infeed_enqueue_tuple( - inputs=infeed_inputs, - shapes=[x.shape for x in infeed_inputs], - name="enqueue/replica_{0}/input_{1}".format( - replica_index, logical_core), - device_ordinal=ordinal)) - return per_host_enqueue_ops - - def _check_input_partition_dims(self, tensor, dims): - """Checks that input partition dims are valid for the `Tensor`. - - Args: - tensor: Input tensor for partitioning. - dims: 1-D np.array of the list of integer describes how to partition the - input tensor. - - Raises: - ValueError: If the tensor can't be partitioned by dims or the - num_cores_per_replica doesn't match the number of - partitions(dims.prod()). - """ - if (dims < 1).any(): - raise ValueError("All input partition dims must be >= 1.") - - # No partitioning, so don't perform further checks. - if dims.prod() == 1: - return - - if dims.prod() != self._device_assignment.num_cores_per_replica: - raise ValueError( - "The product of each input parition dim should equal to " - "num_cores_per_replica. (dim = {}, num_cores_per_replica " - "= {})".format(dims, self._device_assignment.num_cores_per_replica)) - if dims.shape[0] != tensor.shape.ndims: - raise ValueError( - "Input partition dims must have the same number of dimensions " - "as the `Tensor` to be partitioned. (tensor shape = {}, input " - "partition dims = {}).".format(tensor.shape.as_list(), dims)) - - tensor.shape.assert_is_fully_defined() - - def _partition_or_replicate_on_host(self, tensor, dims): - """Partitions or replicates the input tensor. - - The ops inside this function are placed on the host side. - - Args: - tensor: The input tensor which will be partioned or replicated. - dims: A list of integer describes how to partition the input tensor. - Returns: - An iterator of `Tensor`s or a list of partioned tensors. - """ - if dims is None: - return itertools.repeat(tensor) - dims = np.array(dims) - self._check_input_partition_dims(tensor, dims) - output = [tensor] - shape_list = np.array(tensor.shape.as_list()) - quotients, remainders = np.divmod(shape_list, dims) - for axis, (quotient, remainder, dim, original_size) in enumerate( - zip(quotients, remainders, dims, shape_list)): - if dim <= 1: - continue - if remainder > 0: - # For each dimension, when it cannot be evenly partitioned, XLA assumes - # tensors are partitioned in a greedy manner by using - # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims - # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] => - # [[(3, 4), (3, 4), (2, 4), (2, 2)], - # [(2, 4), (2, 4), (2, 4), (2, 2)]] - ceil_ratio = quotient + 1 - num_full_slots, left_over = np.divmod(original_size, ceil_ratio) - num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over] - if len(num_or_size_splits) < dim: - num_or_size_splits += [0] * (dim - len(num_or_size_splits)) - new_output = [] - for x in output: - new_output.append( - array_ops.split( - x, num_or_size_splits=num_or_size_splits, axis=axis)) - output = new_output - else: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output - - def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensor. - - Args: - tensor: The dequeued tensor on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same tensor with the xla_sharding attribute. - """ - if dims is None: - return xla_sharding.replicate(tensor) - elif np.prod(dims) == 1: - return xla_sharding.assign_device(tensor, 0) - else: - tile_assignment = np.arange(np.prod(dims)).reshape(dims) - return xla_sharding.tile( - tensor=tensor, - tile_assignment=tile_assignment) - - def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): - """Tags appropriate XLA sharding attribute to the dequeued tensors. - - Args: - dequeues: A list of dequeued tensors on TPU. - dims: A list of integer describes how the tensor is partitioned. - - Returns: - The same dequeues with appropriate xla_sharding attribute. - """ - nest.assert_shallow_structure(dequeues, dims) - return nest.map_structure_up_to( - dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, - dims) +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_feed import * +# used by tests +from tensorflow.python.tpu.tpu_feed import _PartitionedInfeedQueue +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py index 84d5967ea547f0c036f7c9aa936ac0c99c141304..f2755c6979c2e49dbc19b6800462949601811496 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py @@ -1,57 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Helper library for functions used during TPU compilation.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib - - -class TpuContext(object): - """A context object holding state about the TPU computation being built.""" - - def __init__(self): - """Creates a new TpuContext.""" - self._number_of_shards = None - - @property - def number_of_shards(self): - return self._number_of_shards - - def set_number_of_shards(self, number_of_shards): - self._number_of_shards = number_of_shards - - -# The Tpu context holds the number of shards when a sharded computation is -# being built, or None if no computation is being built. -_current_tpu_context = TpuContext() - - -@contextlib.contextmanager -def tpu_shard_context(number_of_shards): - if _current_tpu_context.number_of_shards is not None: - raise NotImplementedError("tpu_shard_context cannot be nested.") - try: - _current_tpu_context.set_number_of_shards(number_of_shards) - yield - finally: - _current_tpu_context.set_number_of_shards(None) - - -def get_tpu_context(): - return _current_tpu_context +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_function import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index 1e11de6421e360faf0b9ad573a84f9aecdf9c98f..ca58e78d7b342c7ca70400652d99092ccbecbbde 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -1,203 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Optimizer that implements cross-shard gradient reduction for TPU.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import ops -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import optimizer - - -class CrossShardOptimizer(optimizer.Optimizer): - """An optimizer that averages gradients across TPU shards.""" - - def __init__(self, - opt, - reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer", - group_assignment=None): - """Construct a new cross-shard optimizer. - - Args: - opt: An existing `Optimizer` to encapsulate. - reduction: The reduction to apply to the shard losses. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "CrossShardOptimizer". - group_assignment: Optional 2d int32 lists with shape - [num_groups, num_replicas_per_group] which describles how to apply - optimizer to subgroups. - - Raises: - ValueError: If reduction is not a valid cross-shard reduction. - """ - if reduction not in (losses.Reduction.SUM, losses.Reduction.MEAN): - raise ValueError("Unsupported reduction: %s." % reduction) - - super(CrossShardOptimizer, self).__init__(False, name) - self._opt = opt - self._reduction = reduction - self._group_assignment = group_assignment - - def _verify_and_get_subgroup_size(self, group_assignment, num_shards): - """Verify group_assignment and get the subgroup size". - - Args: - group_assignment: list of group ids for applying the optimizer - to subgroups. - num_shards: The number of TPU shards. - - Returns: - The size of one subgroup in group_assignment. - - Raises: - ValueError: If group_assignment is invalid. - """ - if not group_assignment: - return None - if not (isinstance(group_assignment, list) and - all(isinstance(i, list) for i in group_assignment)): - raise ValueError("group_assignment must be a list of list. Got {}".format( - group_assignment)) - - replica_ids = set() - for g in group_assignment: - for i in g: - replica_ids.add(i) - - if set(range(num_shards)) != replica_ids: - raise ValueError("group_assignment must be a permutation of range({0})." - " Got group_assignment={1}".format( - num_shards, group_assignment)) - - subgroup_size_list = [len(group) for group in group_assignment] - if all(subgroup_size_list[0] == size for size in subgroup_size_list): - return subgroup_size_list[0] - else: - raise ValueError("The size of each subgroup in group_assignment must " - "be equal. Got group_assignment={}".format( - self._group_assignment)) - - def compute_gradients(self, loss, var_list=None, **kwargs): - """Compute gradients of "loss" for the variables in "var_list". - - This simply wraps the compute_gradients() from the real optimizer. The - gradients will be aggregated in the apply_gradients() so that user can - modify the gradients like clipping with per replica global norm if needed. - The global norm with aggregated gradients can be bad as one replica's huge - gradients can hurt the gradients from other replicas. - - Args: - loss: A Tensor containing the value to minimize. - var_list: Optional list or tuple of `tf.Variable` to update to minimize - `loss`. Defaults to the list of variables collected in the graph - under the key `GraphKey.TRAINABLE_VARIABLES`. - **kwargs: Keyword arguments for compute_gradients(). - - Returns: - A list of (gradient, variable) pairs. - - Raises: - ValueError: If not within a tpu_shard_context or group_assignment is - invalid. - """ - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "CrossShardOptimizer should be used within a tpu_shard_context, but " - "got unset number_of_shards. Assuming 1.") - num_shards = 1 - - subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, - num_shards) - - if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - if self._group_assignment: - scale = 1.0 / subgroup_size - else: - scale = 1.0 / num_shards - loss *= scale - - return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) - - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """Apply gradients to variables. - - Calls tpu_ops.cross_replica_sum() to sum gradient contributions across - replicas, and then applies the real optimizer. - - Args: - grads_and_vars: List of (gradient, variable) pairs as returned by - compute_gradients(). - global_step: Optional Variable to increment by one after the - variables have been updated. - name: Optional name for the returned operation. Default to the - name passed to the Optimizer constructor. - - Returns: - An `Operation` that applies the gradients. If `global_step` was not None, - that operation also increments `global_step`. - - Raises: - ValueError: If the grads_and_vars is malformed. - """ - summed_grads_and_vars = [] - for (grad, var) in grads_and_vars: - if grad is None: - summed_grads_and_vars.append((grad, var)) - else: - with ops.colocate_with(grad): - summed_grads_and_vars.append((tpu_ops.cross_replica_sum( - grad, self._group_assignment), var)) - return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) - - def get_slot(self, *args, **kwargs): - """Return a slot named "name" created for "var" by the Optimizer. - - This simply wraps the get_slot() from the actual optimizer. - - Args: - *args: Arguments for get_slot(). - **kwargs: Keyword arguments for get_slot(). - - Returns: - The `Variable` for the slot if it was created, `None` otherwise. - """ - return self._opt.get_slot(*args, **kwargs) - - def get_slot_names(self, *args, **kwargs): - """Return a list of the names of slots created by the `Optimizer`. - - This simply wraps the get_slot_names() from the actual optimizer. - - Args: - *args: Arguments for get_slot(). - **kwargs: Keyword arguments for get_slot(). - - Returns: - A list of strings. - """ - return self._opt.get_slot_names(*args, **kwargs) - - def variables(self): - """Forwarding the variables from the underlying optimizer.""" - return self._opt.variables() +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_optimizer import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index f5af03f33ca8f13af517007672e9ce0e12be6205..93c52335a582e5fa83092f78212ca268079b7c12 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -1,253 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Helper library for sharding during TPU compilation.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensorflow.python.framework import tensor_shape - -_DEFAULT_NUMBER_OF_SHARDS = 1 -_DEFAULT_SHARD_DIMENSION = 0 - - -# TODO(b/36777903) change other parts of tpu.py to use this class. -class ShardingPolicy(object): - """An object use to hold the sharding policy for a Tensor. - """ - - def __init__(self): - self._number_of_shards = None - self._shard_dimension = None - self._frozen = False - - def __str__(self): - if self.number_of_shards is None or self.shard_dimension is None: - return "ShardingPolicy(unset)" - else: - return ("ShardingPolicy(%d shards dimension %d)" % - (self.number_of_shards, self.shard_dimension)) - - def _fill_default_values(self): - if self._number_of_shards is None: - self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS - if self._shard_dimension is None: - self._shard_dimension = tensor_shape.as_dimension( - _DEFAULT_SHARD_DIMENSION) - - def freeze(self): - """Prevents further modification to the sharding policy. - - Any values that have not been set when freeze is called are set to - defaults. If the ShardingPolicy is already frozen, this is a NoOp. - """ - if not self._frozen: - self._fill_default_values() - self._frozen = True - - @property - def number_of_shards(self): - """Returns the number of shards in the policy or None if unspecified.""" - return self._number_of_shards - - def set_number_of_shards(self, number_of_shards): - """Sets the number of shards for the current policy. - - If the policy has been frozen then number_of_shards must match the - existing setting. - - Args: - number_of_shards: The number of shards to use in the policy. - - Raises: - ValueError: If the policy has been frozen and number_of_shards - differs from the frozen value; or number_of_shards <= 0. - """ - if self._frozen: - if self._number_of_shards != number_of_shards: - raise ValueError( - "Can't set sharding policy to use %d shards since it has been " - "frozen to use %d." % (number_of_shards, self._number_of_shards)) - else: - if number_of_shards > 0: - self._number_of_shards = number_of_shards - else: - raise ValueError( - "Can't set sharding policy to use %s shards; value must be >0", - str(number_of_shards)) - - @property - def shard_dimension(self): - """Returns the shard dimension of the policy or None if unspecified.""" - return self._shard_dimension - - def set_shard_dimension(self, shard_dimension): - """Sets the shard dimension for the current policy. - - If the policy has been frozen then shard_dimension must match the - existing setting. - - Args: - shard_dimension: The shard dimension to use in the policy. - - Raises: - ValueError: If the policy has been frozen and shard_dimension - differs from the frozen value, or shard_dimension can't be - interpreted as a Dimension. - """ - if self._frozen: - if self._shard_dimension != shard_dimension: - raise ValueError( - "Can't set shard dimension to %d since it has been frozen to " - "use %d." % (shard_dimension, self._shard_dimension)) - else: - self._shard_dimension = tensor_shape.as_dimension(shard_dimension) - - def merge(self, other): - """Merges the policy of another policy into the current policy. - - Args: - other: The policy to merge into this one. - - Raises: - ValueError: If this policy has been frozen and the merge conflicts with - the frozen policy. - """ - if other.number_of_shards is not None: - self.set_number_of_shards(other.number_of_shards) - if other.shard_dimension is not None: - self.set_shard_dimension(other.shard_dimension) - - def get_sharded_shape(self, shape, shard_index=None): - """Returns the shape of a shard of a full Tensor. - - When given the shape of a 'full-size' Tensor, returns the shape of - the sub-Tensor after it has been sharded. Freezes the policy if it - has not yet been frozen. - - Args: - shape: The shape of the full-size Tensor to be sharded. - shard_index: The index of the shard whose shape should be returned. - shard_index can be None for sharding policies that use the same - shape for every shard. - freeze_config: - - Returns: - The shape of the sharded version of the Tensor. - - Raises: - ValueError: If shard_index is None when shards are of different - shapes; or shard_index is not None and - !(0<=shard_index= self.number_of_shards: - raise ValueError("shard_index %d, but must be in [0,%d)." % - (shard_index, self._number_of_shards)) - shape = tensor_shape.as_shape(shape) - if self._number_of_shards == 1: - # Don't do anything when there's only one shard. - return shape - ndims = shape.ndims - if ndims is None: - raise ValueError("shape must be a specified shape not Unknown") - if ndims <= self._shard_dimension: - raise ValueError("shape %s does not contain shard_dimension %d" % - (shape.as_list(), self._shard_dimension)) - dims = shape.as_list() - if dims[self._shard_dimension] is None: - raise ValueError("shape %s must have a fixed size for dimension %d " - "that is known at graph construction time." % - (shape.as_list(), self._shard_dimension)) - if (dims[self._shard_dimension] % self._number_of_shards) != 0: - raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % - (shape.as_list(), self._number_of_shards, - self._shard_dimension)) - dims[self._shard_dimension] /= self._number_of_shards - return tensor_shape.as_shape(dims) - - def _unshard_shape(self, shape): - """Return the unsharded shape that would generate a given sharded shape. - - Args: - shape: the sharded shape to unshard - - Returns: - The unsharded shape. - - Raises: - ValueError: if shape is unknown or does not contain - self.shard_dimension - TypeError: if shape is not convertible to a TensorShape - """ - shape = tensor_shape.as_shape(shape) - if self._number_of_shards == 1: - # Don't do anything when there's only one shard. - return shape - ndims = shape.ndims - if ndims is None: - raise ValueError("shape must be a specified shape not Unknown") - if ndims <= self._shard_dimension: - raise ValueError("shape %s does not contain shard_dimension %d" % - (shape.as_list(), self._shard_dimension)) - dims = shape.as_list() - dims[self._shard_dimension] *= self._number_of_shards - return tensor_shape.as_shape(dims) - - def get_unsharded_shape(self, shapes): - """Returns the shape of an unsharded Tensor given a list of shards. - - When given a list of shapes of shards, returns the shape of the - unsharded Tensor that would generate the shards. Sets defaults for the - policy if number_of_shards or shard_dimension is None. - - Args: - shapes: The shapes of the Tensor shards to be combined. - - Returns: - The shape of the unsharded version of the Tensor. - - Raises: - ValueError: if shapes is not a list of length - self.number_of_shards; or any element of shapes is not a valid - shape consistent with the sharding policy; or the list of - shapes is not a valid sharding of a full shape. - TypeError: if an element of shapes is not convertible to a - TensorShape - """ - self._fill_default_values() - if len(shapes) != self.number_of_shards: - raise ValueError( - "shapes is %s but must be a list of length number_of_shards=%d" % ( - str(shapes), self.number_of_shards)) - unsharded_shapes = [self._unshard_shape(s) for s in shapes] - for i in xrange(self.number_of_shards - 1): - if not unsharded_shapes[i].is_compatible_with( - unsharded_shapes[self.number_of_shards - 1]): - raise ValueError( - "sharded shapes %s are not consistent shards of a full shape " - "sharded %d ways along dimension %d" % ( - str(shapes), self.number_of_shards, self.shard_dimension)) - return unsharded_shapes[0] +# pylint: disable=wildcard-import,unused-import,redefined-builtin +from tensorflow.python.tpu.tpu_sharding import * +# pylint: enable=wildcard-import,unused-import,redefined-builtin diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index d66ecfcf4a56b8da1c2d2f518bebe4baa76b315e..258d34ddaf5250e49c5a354caf018e4b64abae62 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -1,156 +1,25 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== -"""TPU system metadata and associated tooling.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import re - -from tensorflow.contrib.tpu.python.tpu import tpu -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging - -_PINGING_MASTER_TIMEOUT_IN_MS = 60 * 1000 # 1 min -_RETRY_TIMES = 120 -_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins - -_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$') - -# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration, -# including num_cores and num_hosts. -_TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ - 'num_cores', - 'num_hosts', - 'num_of_cores_per_host', - 'topology', - 'devices', -]) - - -def _query_tpu_system_metadata(master_address, cluster_def=None, - query_topology=False): - """Automatically detects the TPU system metadata in the system.""" - tpu_core_count = 0 - devices = [] - device_dict = collections.defaultdict(list) - - # TODO(b/120564445): Replace with standard library for retries. - retry_count = 1 - while True: - logging.info('Querying Tensorflow master (%s) for TPU system metadata.', - master_address) - try: - with ops.Graph().as_default(): - with session_lib.Session( - master_address, - config=get_session_config_with_timeout( - _PINGING_MASTER_TIMEOUT_IN_MS, - cluster_def)) as sess: - devices = sess.list_devices() - for device in devices: - match = _TPU_DEVICE_REG.match(device.name) - if match: - host_id = match.group(1) - core_id = match.group(2) - device_dict[host_id].append(core_id) - tpu_core_count += 1 - break - except errors.DeadlineExceededError: - msg = ('Failed to connect to the Tensorflow master. The TPU worker may ' - 'not be ready (still scheduling) or the Tensorflow master address ' - 'is incorrect: got (%s).' % - (master_address)) - - # TODO(xiejw): For local or grpc master we might not need retry logic - # here. - if retry_count <= _RETRY_TIMES: - logging.warning('%s', msg) - logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES) - retry_count += 1 - else: - raise ValueError(msg) - - num_of_cores_per_host = 0 - if tpu_core_count: - num_cores_per_host_set = set( - [len(core_ids) for core_ids in device_dict.values()]) - if len(num_cores_per_host_set) != 1: - raise RuntimeError( - 'TPU cores on each host is not same. This should not happen!. ' - 'devices: {}'.format(devices)) - num_of_cores_per_host = num_cores_per_host_set.pop() - - topology = None - if query_topology: - if not tpu_core_count: - raise RuntimeError( - 'Cannot find any TPU cores in the system (master address {}). ' - 'This usually means the master address is incorrect or the ' - 'TPU worker has some problems. Available devices: {}'.format( - master_address, devices)) - - topology = _obtain_topology(master_address, cluster_def) - - metadata = _TPUSystemMetadata( - num_cores=tpu_core_count, - num_hosts=len(device_dict), - num_of_cores_per_host=num_of_cores_per_host, - topology=topology, - devices=devices) - - if tpu_core_count: - logging.info('Found TPU system:') - logging.info('*** Num TPU Cores: %d', metadata.num_cores) - logging.info('*** Num TPU Workers: %d', metadata.num_hosts) - logging.info('*** Num TPU Cores Per Worker: %d', - metadata.num_of_cores_per_host) - for device in metadata.devices: - logging.info('*** Available Device: %s', device) - else: - logging.info('Failed to find TPU: %s', metadata) - return metadata - - -def _obtain_topology(master_address, cluster_def): - """Obtains TPU fabric topology.""" - try: - logging.info('Initializing TPU system (master: %s) to fetch topology ' - 'for model parallelism. This might take a while.', - master_address) - with ops.Graph().as_default(): - session_config = get_session_config_with_timeout( - _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def) - with session_lib.Session( - master_address, config=session_config) as sess: - topology = sess.run(tpu.initialize_system()) - return topology - except errors.DeadlineExceededError: - raise ValueError( - 'Fail to initialize TPU system with master (%s). ' - 'Please double check the TPU system is functional.' % ( - master_address)) - - -def get_session_config_with_timeout(timeout_in_secs, cluster_def): - """Returns a session given a timeout and a cluster configuration.""" - config = config_pb2.ConfigProto( - operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) - return config +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.tpu_system_metadata import * +# used by tests +from tensorflow.python.tpu.tpu_system_metadata import _query_tpu_system_metadata +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py index 0187b4bec6ecc55943bf48b9268a74e18ea5b488..673359b232d6857d468723873c449cb3e48168c7 100644 --- a/tensorflow/contrib/tpu/python/tpu/training_loop.py +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -1,214 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================= - -"""Library for constructing a training loop, suitable for TPUs.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.compiler import xla -from tensorflow.contrib.tpu.python.tpu import tpu_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops - - -def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop for TPUs. - - The set of loop-carried tensors corresponds to `inputs`. Both - `condition` and `body` take the current value of the loop-carried - tensors. 'body' additionally takes a tuple of infeed from - infeed_queue if infeed_queue is not None. `condition` must return a - single boolean value that determines whether iteration - continues. `body` must return an updated list of values for the - loop-carried tensors. - - Args: - condition: a Python function that builds the loop condition. - body: a Python function that builds the loop body. - inputs: a list of initial values passed into the training loop, or - None (equivalent to an empty list). - infeed_queue: if not None, the infeed queue from which to append a tuple - of arguments as inputs to condition. - name: (Deprecated) Does nothing. - - Returns: - The final values of the loop-carried tensors. - - Raises: - TypeError: if body or condition has the wrong signature. - """ - del name - # Converts inputs to Tensors. - inputs = [] if inputs is None else [ops.convert_to_tensor(x) for - x in inputs] - input_types = [x.dtype for x in inputs] - input_arity = len(inputs) - - body_arg_error = xla.check_function_argument_count( - body, input_arity, infeed_queue) - if body_arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied loop body function cannot be called with the specified " - "inputs. You specified %d inputs: %s, but the loop body needs %s" % ( - input_arity, str([i.name for i in inputs]), body_arg_error)) - else: - raise TypeError( - "Supplied loop body function cannot be called with the specified " - "inputs. You specified %d inputs: %s and %d additional inputs from " - "infeed, but the computation needs %s" % (input_arity, str( - [i.name for i in inputs]), infeed_queue.number_of_tuple_elements, - body_arg_error)) - condition_arg_error = xla.check_function_argument_count( - condition, input_arity, None) - if condition_arg_error is not None: - if infeed_queue is None: - raise TypeError( - "Supplied loop condition function cannot be called with the " - "specified inputs. You specified %d inputs: %s, but the loop " - "condition needs %s" % (input_arity, str([i.name for i in inputs]), - condition_arg_error)) - else: - raise TypeError( - "Supplied loop condition function cannot be called with the " - "specified inputs. You specified %d inputs: %s, but the loop " - "condition needs %s. Note that infeed is not passed to the loop " - "condition." % (input_arity, str([i.name for i in inputs]), - condition_arg_error)) - - def condition_wrapper(*inputs): - # Discards the dummy output added for arity-0 loops. - if input_arity == 0: - inputs = [] - return condition(*inputs) - - def body_wrapper(*inputs): - """Wrapper around `body` that handles infeed queues and control deps.""" - inputs = list(inputs) - - # Discards the dummy output added for arity-0 loops. - if input_arity == 0: - inputs = [] - - # Runs `body` with the dequeue_ops appended. - if infeed_queue: - number_of_shards = tpu_function.get_tpu_context().number_of_shards - if number_of_shards is None: - raise ValueError("Can't build training loop with infeed when there is " - "no tpu_shard_context. Are you building a loop or " - "graph directly rather than from inside tpu.rewrite, " - "tpu.batch_parallel, tpu.shard, or tpu.replicate?") - infeed_queue.set_number_of_shards(number_of_shards) - dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()] - else: - dequeue_ops = [] - outputs = body(*(inputs + dequeue_ops)) - - # If the computation only returned one value, make it a tuple. - if not isinstance(outputs, (list, tuple)): - outputs = (outputs,) - - outputs = [ - o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) - for o in outputs - ] - - # Separates the returned Operations and Tensors. - output_operations = [o for o in outputs if isinstance(o, ops.Operation)] - output_tensors = [o for o in outputs - if not isinstance(o, ops.Operation)] - - if outputs != output_tensors + output_operations: - raise ValueError( - "TPU training loop body must return zero or more Tensor values " - "followed by zero or more Operations.") - - output_types = [op.dtype for op in output_tensors] - if input_types != output_types: - raise TypeError( - "Mismatch between input types and output types for training loop " - "body: {} vs {}".format(input_types, output_types)) - - # Add the dequeue operations to output_operations to ensure they are run - # by the loop, even if the programmer's loop body does not use them. - output_operations += dequeue_ops - - # Add a dummy output, if needed. - if not output_tensors: - output_tensors = array_ops.constant(0) - - if output_operations: - # TODO(phawkins): in principle this is too restrictive since it serializes - # the training loop steps. In practice it does not matter since this loop - # will be compiled by XLA. - return control_flow_ops.tuple(output_tensors, - control_inputs=output_operations) - else: - return output_tensors - - # If the body has arity 0, add a dummy loop-carried value to which we can add - # control dependencies from any side-effecting operations. - if input_arity == 0: - inputs = [array_ops.constant(0)] - return control_flow_ops.while_loop( - condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1) - - -def repeat(n, body, inputs=None, infeed_queue=None, name=None): - """Builds a training loop that executes a fixed number of iterations. - - The set of loop-carried tensors correspond to `inputs`. - `body` must be a function that takes and returns the values of the - loop-carried tensors. - - Args: - n: the number of loop iterations - body: a Python function that builds the loop body. - inputs: a list of initial values passed into the training loop or - None (equivalent to an empty list). - infeed_queue: if not None, the infeed queue from which to append a tuple - of arguments as inputs to condition. - name: (Deprecated) Does nothing. - Returns: - The final values of the loop-carried tensors. - Raises: - ValueError: if there is a type error. - """ - def _convert_to_list(xs): - if not isinstance(xs, (list, tuple)): - return [xs] - else: - return list(xs) - - def cond(i, *args): - del args - return i < n - - def body_wrapper(i, *args): - return [i + 1] + _convert_to_list(body(*args)) - - inputs = [0] if inputs is None else [0] + _convert_to_list(inputs) - outputs = while_loop( - cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name) - outputs = _convert_to_list(outputs) - if len(outputs) == 1: - # Returns the Op rather than an empty list. - return outputs[0].op - else: - return outputs[1:] +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.training_loop import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py index dfb8ce1d1821da05c853bb0d10b1db3a857ccb1b..8d9b70d46eb42c9a525eeafc51d07f0ad4241d52 100644 --- a/tensorflow/contrib/tpu/python/tpu/util.py +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -1,51 +1,23 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# =================================================================== - -"""Utilities for the functionalities.""" +# ============================================================================== +"""Stub file to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time -import six - -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import training - -def check_positive_integer(value, name): - """Checks whether `value` is a positive integer.""" - if not isinstance(value, six.integer_types): - raise TypeError('{} must be int, got {}'.format(name, type(value))) - - if value <= 0: - raise ValueError('{} must be positive, got {}'.format(name, value)) - - -# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we -# release a tensorflow_estimator with MultiHostDatasetInitializerHook in -# python/estimator/util.py. -class MultiHostDatasetInitializerHook(training.SessionRunHook): - """Creates a SessionRunHook that initializes all passed iterators.""" - - def __init__(self, dataset_initializers): - self._initializers = dataset_initializers - - def after_create_session(self, session, coord): - del coord - start = time.time() - session.run(self._initializers) - logging.info('Initialized dataset iterators in %d seconds', - time.time() - start) +# pylint: disable=wildcard-import,unused-import +from tensorflow.python.tpu.util import * +# pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index 093765dc2098d2135a1d86aa44b23c13546267ee..4ceb6e9350f5167efc8f7266d4e748cc6fa4ffd6 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -432,7 +432,7 @@ def create_train_op(total_loss, else: # Make sure that variables_to_train are in tf.trainable_variables() for v in variables_to_train: - assert v in tf_variables.trainable_variables() + assert v.trainable or v in tf_variables.trainable_variables() assert variables_to_train diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 07dbd5ca8d65ec8232d33c016a7369c68a4c9e1f..ada08f95ae46ea06b3896ca3b1603277d62bf6fc 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -22,7 +22,9 @@ cc_library( "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:tensorflow", "//tensorflow/core/kernels:immutable_constant_op", ], diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 3d92a836d1c21845407ec53bd46a24638e158e3b..06c108b38fbf1d4b796c313ce700332803c73ef9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -77,6 +77,7 @@ load( "//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_android", + "if_emscripten", "if_ios", "if_linux_x86_64", "if_mobile", @@ -87,10 +88,12 @@ load( "tf_copts", "tf_cuda_library", "tf_features_nomodules_if_android", + "tf_features_nomodules_if_emscripten", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", + "tf_opts_nortti_if_emscripten", "transitive_hdrs", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") @@ -125,7 +128,6 @@ load( "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", - "tf_additional_proto_compiler_hdrs", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", @@ -144,6 +146,7 @@ load( "tf_protos_grappler", "tf_protos_grappler_impl", "tf_pyclif_proto_library", + "tf_grpc_service_all", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -226,13 +229,15 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS # ones with individual proto_library targets. ADDITIONAL_CORE_PROTO_SRCS = [ "example/example_parser_configuration.proto", - "protobuf/checkpointable_object_graph.proto", + "protobuf/trackable_object_graph.proto", "protobuf/control_flow.proto", # TODO(ebrevdo): Re-enable once CriticalSection is in core. # "protobuf/critical_section.proto", "protobuf/meta_graph.proto", "protobuf/named_tensor.proto", "protobuf/saved_model.proto", + "protobuf/saved_object_graph.proto", + "protobuf/struct.proto", "protobuf/tensorflow_server.proto", "protobuf/transport_options.proto", "util/test_log.proto", @@ -415,9 +420,8 @@ cc_library( name = "platform_protobuf", srcs = tf_platform_hdrs([ "protobuf.h", - ]) + tf_platform_srcs([ - "protobuf.cc", ]) + [ + "platform/protobuf.cc", "platform/protobuf_util.cc", "lib/core/status.h", ], @@ -436,6 +440,17 @@ cc_library( ], ) +cc_library( + name = "grpc_services", + srcs = [], + hdrs = [ + "platform/grpc_services.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = tf_grpc_service_all(), +) + cc_library( name = "human_readable_json", srcs = tf_platform_srcs(["human_readable_json.cc"]), @@ -454,10 +469,7 @@ cc_library( hdrs = ["platform/logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], - deps = [ - ":lib_proto_parsing", - "@protobuf_archive//:protobuf", - ], + deps = [":lib_proto_parsing"], ) filegroup( @@ -664,7 +676,7 @@ cc_library( name = "lib_proto_compiler", hdrs = [ "platform/protobuf_compiler.h", - ] + tf_additional_proto_compiler_hdrs(), + ], copts = tf_copts(), deps = tf_lib_proto_compiler_deps() + [ ":lib_proto_parsing", @@ -1049,13 +1061,13 @@ cc_library( "platform/default/integral_types.h", "platform/default/logging.h", "platform/default/mutex.h", - "platform/default/protobuf.h", "platform/default/thread_annotations.h", "platform/dynamic_annotations.h", "platform/macros.h", "platform/mutex.h", "platform/platform.h", "platform/prefetch.h", + "platform/protobuf.h", "platform/thread_annotations.h", "platform/types.h", "platform/cpu_info.h", @@ -1141,6 +1153,13 @@ tf_gen_op_libs( deps = [":protos_all_cc"], ) +tf_gen_op_libs( + op_lib_names = [ + "mkl_array_ops", + ], + deps = [":protos_all_cc"], +) + tf_gen_op_libs( op_lib_names = [ "audio_ops", @@ -1161,6 +1180,29 @@ tf_gen_op_libs( deps = [":lib"], ) +tf_gen_op_libs( + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + deps = [ + ":lib", + ":lib_proto_parsing", + ":protos_all_cc", + "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + ], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1277,10 +1319,23 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(), + ] + if_mkl([ + ":mkl_array_ops_op_lib", + ":mkl_nn_ops_op_lib", + ]) + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -1382,7 +1437,7 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. cc_library( name = "all_kernels_impl", - visibility = ["//visibility:private"], + visibility = ["//tensorflow/core:__subpackages__"], deps = [ "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", @@ -1452,6 +1507,7 @@ cc_library( "//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", + "//tensorflow/core/kernels:mkl_requantize_ops", "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", @@ -1540,6 +1596,7 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", + ":ops", ":protos_all_cc", ":shape_inference_testutil", ":tensor_testutil", @@ -1770,6 +1827,29 @@ cc_library( ], ) +cc_library( + name = "emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + srcs = if_emscripten(["//tensorflow/core:mobile_srcs_no_runtime"]), + copts = ["-DSUPPORT_SELECTIVE_REGISTRATION"] + tf_opts_nortti_if_emscripten(), + defines = ["TENSORFLOW_LITE_PROTOS"], + linkopts = ["-lz"], + tags = [ + "manual", + "notap", + ], + visibility = ["//visibility:public"], + deps = [ + ":emscripten_proto_lib_no_rtti_lite_runtime", + ":mobile_additional_lib_deps", + ":stats_calculator_portable", + "//third_party/eigen3", + "@double_conversion//:double-conversion", + "@nsync//:nsync_cpp", + "@zlib_archive//:zlib", + ], + alwayslink = 1, +) + # Native library support for iOS applications. # # bazel build --config=ios_x86_64 \ @@ -1863,6 +1943,7 @@ filegroup( "**/*testutil*", "**/*testlib*", "**/*main.cc", + "**/tpu_*", ], ), visibility = ["//visibility:public"], @@ -2248,6 +2329,7 @@ cc_library( "platform/**/logging.cc", "platform/**/human_readable_json.cc", "platform/abi.cc", + "platform/protobuf.cc", ], ) + tf_additional_lib_srcs( exclude = [ @@ -2274,6 +2356,8 @@ cc_library( ":lib_proto_parsing", ":abi", ":core_stringpiece", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//third_party/eigen3", "//tensorflow/core/platform/default/build_config:platformlib", "@snappy", @@ -2653,7 +2737,6 @@ tf_cuda_library( "example/**/*.cc", "framework/**/*.cc", "util/**/*.cc", - ] + [ "graph/edgeset.cc", "graph/graph.cc", "graph/graph_def_builder.cc", @@ -2898,6 +2981,7 @@ tf_cuda_library( CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/allocator_retry.h", + "common_runtime/shared_counter.h", "common_runtime/base_collective_executor.h", "common_runtime/bfc_allocator.h", "common_runtime/hierarchical_tree_broadcaster.h", @@ -2922,6 +3006,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/lower_if_while.h", "common_runtime/lower_while_op.h", "common_runtime/memory_types.h", + "common_runtime/metrics.h", "common_runtime/mkl_cpu_allocator.h", "common_runtime/optimization_registry.h", "common_runtime/pending_counts.h", @@ -2933,6 +3018,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/rendezvous_mgr.h", "common_runtime/rendezvous_util.h", "common_runtime/ring_reducer.h", + "common_runtime/ring_alg.h", + "common_runtime/ring_gatherer.h", "common_runtime/session_factory.h", "common_runtime/single_threaded_cpu_device.h", "common_runtime/stats_publisher_interface.h", @@ -2957,6 +3044,8 @@ tf_cuda_library( "common_runtime/collective_param_resolver_local.cc", "common_runtime/collective_rma_local.cc", "common_runtime/collective_util.cc", + "common_runtime/colocation_graph.cc", + "common_runtime/colocation_graph.h", "common_runtime/constant_folding.cc", "common_runtime/copy_tensor.cc", "common_runtime/costmodel_manager.cc", @@ -2977,6 +3066,7 @@ tf_cuda_library( "common_runtime/lower_if_while.cc", "common_runtime/lower_while_op.cc", "common_runtime/memory_types.cc", + "common_runtime/metrics.cc", "common_runtime/mkl_cpu_allocator.cc", "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", @@ -2989,6 +3079,8 @@ tf_cuda_library( "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", "common_runtime/rendezvous_util.cc", + "common_runtime/ring_alg.cc", + "common_runtime/ring_gatherer.cc", "common_runtime/ring_reducer.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", @@ -3047,7 +3139,6 @@ tf_cuda_library( ":framework", ":graph", ":lib", - ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/grappler:grappler_item", @@ -3078,15 +3169,6 @@ cc_library( deps = [":lib_internal"], ) -tf_cuda_library( - name = "metrics", - srcs = ["common_runtime/metrics.cc"], - hdrs = ["common_runtime/metrics.h"], - deps = [ - ":lib", - ], -) - tf_cuda_library( name = "direct_session_internal", srcs = ["common_runtime/direct_session.cc"], @@ -3103,7 +3185,6 @@ tf_cuda_library( ":graph", ":lib", ":lib_internal", - ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", @@ -3470,6 +3551,7 @@ tf_cc_tests( "platform/vmodule_benchmark_test.cc", ], deps = [ + ":core_cpu_internal", ":lib", ":lib_internal", ":lib_test_internal", @@ -3679,6 +3761,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "lib_strings_proto_serialization_test", + srcs = ["lib/strings/proto_serialization_test.cc"], + deps = [ + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + ":test_main", + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "lib_random_weighted_picker_test", size = "medium", @@ -3886,7 +3982,6 @@ tf_cc_test( "ops/cudnn_rnn_ops_test.cc", ], deps = [ - ":cudnn_rnn_ops", "//tensorflow/core", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -3946,6 +4041,35 @@ tf_cc_tests_gpu( ], ) +tf_cc_tests_gpu( + name = "ring_gatherer_test", + size = "medium", + srcs = [ + "common_runtime/ring_gatherer_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":all_kernels", + ":core", + ":core_cpu", + ":core_cpu_internal", + ":direct_session_internal", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":lib", + ":lib_internal", + ":ops", + ":protos_all_cc", + ":protos_test_cc", + ":test", + ":test_main", + ":testlib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_tests_gpu( name = "hierarchical_tree_broadcaster_test", size = "medium", @@ -4484,7 +4608,7 @@ tf_cc_test( "//tensorflow/cc:scope", "//tensorflow/core/kernels:cwise_op", "//third_party/eigen3", - ], + ] + if_mkl([":mkl_array_ops_op_lib"]), ) tf_cc_test( @@ -5037,6 +5161,39 @@ transitive_hdrs( # ----------------------------------------------------------------------------- # Google-internal targets go here (must be at the end). +load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") + +genrule( + name = "emscripten_proto_config_lite_runtime", + outs = ["emscripten_proto_config_lite_runtime.asciipb"], + cmd = tf_genrule_cmd_append_to_srcs("optimize_mode:LITE_RUNTIME"), + visibility = ["//visibility:private"], +) + +# We are keeping the "android" version of tf_android_core_proto_headers. All it does is +# normalize CORE_PROTO_SRCS to generate valid output file names. +tf_portable_proto_library( + name = "emscripten_proto_lib_no_rtti_lite_runtime", + config = ":emscripten_proto_config_lite_runtime", + copts = tf_opts_nortti_if_emscripten(), + features = tf_features_nomodules_if_emscripten(), + header_outs = tf_android_core_proto_headers(CORE_PROTO_SRCS) + ["//google/protobuf/any.proto.h"], + link_full_protobuf = False, + prefix_dir = "emscripten_proto_no_rtti", + proto_deps = [ + ":protos_all_cc", + "@protobuf_archive//:protobuf", + ], + visibility = ["//visibility:public"], +) + +# There is currently no need for a full proto version of emscripten tf lib lite. +alias( + name = "emscripten_lib_lite_no_runtime", + actual = "//tensorflow/core:emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime", + visibility = ["//visibility:public"], +) + alias( name = "android_srcs_no_runtime", actual = ":mobile_srcs_no_runtime", diff --git a/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d6f28bd022bcd843aa3a7aeb8b1b257a3b3ddfd3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AllToAll.pbtxt @@ -0,0 +1,67 @@ +op { + graph_op_name: "AllToAll" + in_arg { + name: "input" + description: <