diff --git a/.gitignore b/.gitignore index 57d84228cfd037325716b5faa56c17f7424fe713..90324058600bee46af56e49028977971848a80de 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,7 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata -/tensorflow/lite/downloads/** +/tensorflow/lite/tools/make/downloads/** /tensorflow/lite/gen/** /tensorflow/lite/examples/ios/simple/data/*.txt /tensorflow/lite/examples/ios/simple/data/*.tflite diff --git a/WORKSPACE b/WORKSPACE index 17961829a605c2d1f2d2ba86a7c30c47618c139b..0c7bc085b512b084b9470abe17326d7c119aa327 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +http_archive( + name = "base_images_docker", + sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", + strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", + urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], +) + +http_archive( + name = "bazel_toolchains", + sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", + strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", + urls = [ + "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", + ], +) + +http_archive( + name = "io_bazel_rules_docker", + sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", + strip_prefix = "rules_docker-0.5.1", + urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +) + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") + +remote_config_workspace() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. 
@@ -79,3 +106,4 @@ new_http_archive( "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) + diff --git a/configure.py b/configure.py index 2eeeceb3399c79775ce62b9569e940e469141a17..234561d94a46f57c4de5ca487360e2d5a3dfdb2f 100644 --- a/configure.py +++ b/configure.py @@ -43,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -1555,6 +1555,9 @@ def main(): check_bazel_version('0.15.0') reset_tf_configure_bazelrc() + # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later + write_to_bazelrc('import %workspace%/tools/bazel.rc') + cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 11b42f349df89c605f4ce1130033a85c920258c9..17577afecb74b7008db5a282255278b35ed138a6 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = ( TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) +# @unused +TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( + TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. 
config_setting( @@ -213,31 +218,31 @@ config_setting( # config_setting( name = "no_aws_support", - define_values = {"no_aws_support": "false"}, + define_values = {"no_aws_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_gcp_support", - define_values = {"no_gcp_support": "false"}, + define_values = {"no_gcp_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_hdfs_support", - define_values = {"no_hdfs_support": "false"}, + define_values = {"no_hdfs_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_ignite_support", - define_values = {"no_ignite_support": "false"}, + define_values = {"no_ignite_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "no_kafka_support", - define_values = {"no_kafka_support": "false"}, + define_values = {"no_kafka_support": "true"}, visibility = ["//visibility:public"], ) @@ -350,8 +355,9 @@ package_group( "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", - "//tensorflow_estimator/...", + "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", + "//tensorflow_text/...", "//third_party/py/tensor2tensor/...", ], ) @@ -553,18 +559,24 @@ genrule( }), outs = ["__init__.py"], cmd = select({ - "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", - "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)", }), ) gen_api_init_files( name = "tf_python_api_gen_v1", - srcs = ["api_template_v1.__init__.py"], + srcs = [ + "api_template_v1.__init__.py", + "compat_template_v1.__init__.py", + ], api_version = 1, + compat_api_versions = [1], + compat_init_templates = ["compat_template_v1.__init__.py"], output_dir = "_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1, + output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, output_package = 
"tensorflow._api.v1", + root_file_name = "v1.py", root_init_template = "api_template_v1.__init__.py", ) @@ -580,6 +592,7 @@ gen_api_init_files( output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", + root_file_name = "v2.py", root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 0d49756838505289a960a6cabeb7cab02fad995b..2efb8846c6837a3935e0a8439a18838cb2bea804 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -34,7 +34,8 @@ from tensorflow.python.platform import flags # pylint: disable=g-import-not-at- # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. -_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +# We're using bitwise, but there's nothing special about that. +_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable if _tf_api_dir not in __path__: __path__.append(_tf_api_dir) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 16f633643d4726f6e2d1a23c3b192d48dbbc8f14..f653e581bf3beda9fdbf8fb7905a4f9fe170e7fb 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -60,6 +60,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:op_gen_lib", + "//tensorflow/core/distributed_runtime:server_lib", ], }), ) @@ -95,6 +96,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:server_lib", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -119,7 +121,8 @@ tf_cuda_library( ":c_api", ":c_api_internal", "//tensorflow/c/eager:c_api", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/c/eager:c_api_internal", + 
"//tensorflow/compiler/jit:flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -172,6 +175,28 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "kernels", + srcs = [ + "kernels.cc", + ], + hdrs = [ + "kernels.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":c_api", + "//tensorflow/core:framework", + ], + }), +) + # ----------------------------------------------------------------------------- # Tests @@ -199,7 +224,7 @@ tf_cuda_cc_test( size = "small", srcs = ["c_api_test.cc"], data = [ - ":test_op.so", + ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], kernels = [":test_op_kernel"], @@ -207,7 +232,10 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), - tags = ["noasan"], + tags = [ + "no_oss", # http://b/119522529 + "noasan", + ], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -218,6 +246,7 @@ tf_cuda_cc_test( "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/compiler/jit", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", @@ -235,7 +264,7 @@ tf_cuda_cc_test( tf_cc_test( name = "c_api_experimental_test", - size = "small", + size = "medium", srcs = ["c_api_experimental_test.cc"], data = ["testdata/tf_record"], linkopts = select({ @@ -246,8 +275,11 @@ tf_cc_test( # the shared library must be able to use core:framework. 
# linkstatic = tf_kernel_tests_linkstatic(), deps = [ + ":c_api", ":c_api_experimental", ":c_test_util", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -284,8 +316,8 @@ tf_cc_test( ) tf_custom_op_library( - name = "test_op.so", - srcs = ["test_op.cc"], + name = "test_op1.so", + srcs = ["test_op1.cc"], ) tf_kernel_library( @@ -298,6 +330,30 @@ tf_kernel_library( alwayslink = 1, ) +tf_cuda_cc_test( + name = "kernels_test", + size = "small", + srcs = ["kernels_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- # Python API target diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 4540dcd6638a58c25628dccd2fa78f1fe06bef1d..f13e8777dff164bcd8eedf46310ae846abd0c804 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2810,4 +2810,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { } return ret; } + +// TF_Server functions ---------------------------------------------- + +#ifndef __ANDROID__ +TF_Server::TF_Server(std::unique_ptr server) + : target(server->target()), server(std::move(server)) {} +#endif // __ANDROID__ + +TF_Server* TF_NewServer(const void* proto, size_t proto_len, + TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported 
in Android"); + return nullptr; +#else + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { + status->status = InvalidArgument( + "Could not parse provided bytes into a ServerDef protocol buffer"); + return nullptr; + } + + std::unique_ptr out_server; + status->status = tensorflow::NewServer(server_def, &out_server); + if (!status->status.ok()) return nullptr; + + return new TF_Server(std::move(out_server)); +#endif +} + +void TF_ServerStart(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Start(); +#endif +} + +void TF_ServerStop(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Stop(); +#endif +} + +void TF_ServerJoin(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Join(); +#endif +} + +const char* TF_ServerTarget(TF_Server* server) { +#ifdef __ANDROID__ + return nullptr; +#else + return server->target.c_str(); +#endif +} + +void TF_DeleteServer(TF_Server* server) { delete server; } + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index da8ad1cec59e328b9f1a77f81416651a618e97d3..3d56268110edbe96616201d15a69cc8c84d3115a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1668,6 +1668,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for 
use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_DeleteServer). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroys an in-process TensorFlow server and frees its memory. If the server +// is running, it will be stopped and joined. 
+TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index fabe2fa0f60bc8baafa7f83802da74bb7ab93c6d..69de4cb711ef89734af3729c5e5518c14a7f5738 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,13 +15,18 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -51,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. 
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -71,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. - tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -8739,8 +8744,55 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } -TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, - const char* errMsg) { +struct TFE_ExecuteOpNotification { + TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} + tensorflow::Notification n; + std::unique_ptr thread; + std::unique_ptr status; +}; + +TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op, + TFE_TensorHandle** retvals, + int* num_retvals, + TF_Status* status) { + TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification; + + n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread( + tensorflow::ThreadOptions(), "ExecuteOpThread", + [op, retvals, num_retvals, n]() { + TFE_Execute(op, retvals, num_retvals, n->status.get()); + n->n.Notify(); + })); + + return n; +} + +void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status) { + if (notification == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in 
notification is a nullptr."); + + return; + } + if (notification->thread == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification didn't start a thread correctly. Cleaning up " + "this notification. Please re-execute the operation to get a new " + "notification."); + + delete notification; + return; + } + + notification->n.WaitForNotification(); + + status->status = notification->status->status; + + delete notification; +} + +void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } @@ -8800,3 +8852,21 @@ const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index, // The returned string is owned by OpRegistry, so liveness is not a concern. return input_arg.number_attr().c_str(); } + +int TF_OpIsStateful(const char* op_type, TF_Status* status) { + const tensorflow::OpRegistrationData* op_reg_data; + status->status = + tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data); + if (!status->status.ok()) { + return 0; + } + return op_reg_data->op_def.is_stateful(); +} + +void TF_InitMain(const char* usage, int* argc, char*** argv) { + tensorflow::port::InitMain(usage, argc, argv); +} + +int TF_PickUnusedPortOrDie() { + return tensorflow::internal::PickUnusedPortOrDie(); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6639b0be72bdf81d0e3c806770364d7bc5082ad2..c04cd441bfbd89dafc8b3f0882ab06cd98a1b6fb 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -180,6 +180,25 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; + +// Allows invoking a kernel asynchronously, and explicitly returns a +// notification that can be waited upon. This always executes the kernel in a +// new thread. 
+// 1. `retvals` and `num_retvals` can only be consumed after +// `TFE_ExecuteOpNotificationWaitAndDelete` returns successfully. They +// shouldn't be used if the return is unsuccessful. +// 2. These new APIs cannot be used together with the TFE context level async +// support. +TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread( + TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, + TF_Status* status); + +// Waits to complete the op execution, and cleans up the notification. +// Errors reported by op execution are set in `status`. +TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status); + TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); @@ -209,6 +228,19 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice( TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput( const char* op_name, int input_index, TF_Status* status); +// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined +// if the status is not ok. +TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type, + TF_Status* status); + +// Platform-specific initialization routine. Very few platforms actually require +// this to be called. +TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); + +// Platform-specific implementation to return an unused port. (This should be +// used in tests only.) +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index c6effd39697e0397278770b53e98508074f99862..daa7701b7fe7e8ce757b6504329cf6434ad39778 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -162,5 +164,137 @@ protocol: "grpc" TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, IsStateful) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + int assign = TF_OpIsStateful("AssignAddVariableOp", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(assign, 1); + int id = TF_OpIsStateful("Identity", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(id, 0); +} + +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + TFE_Op* matmul_op = MatMulOp(ctx, m, m); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + auto* r = + TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(r, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteOp(matmul_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + 
TF_DeleteStatus(status); +} + +// Perform a send/recv test. Recv blocks, so they need to be executed +// asynchronously. +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4. + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + // Build a send op. + TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(send_op, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + string tensor_name = "Tensor"; + TFE_OpSetAttrType(send_op, "T", TF_FLOAT); + TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + string send_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234); + string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(send_op, "client_terminated", true); + + // Build a recv op. 
+ TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT); + TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234); + TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(recv_op, "client_terminated", true); + + TFE_TensorHandle* send_retvals; + int send_num_retvals = 0; + auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals, + &send_num_retvals, status); + + TFE_TensorHandle* recv_retvals[1] = {nullptr}; + int recv_num_retvals = 1; + auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0], + &recv_num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(send_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, product[0]); + EXPECT_EQ(2, product[1]); + EXPECT_EQ(3, product[2]); + EXPECT_EQ(4, product[3]); + + TFE_DeleteOp(send_op); + TFE_DeleteOp(recv_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(recv_retvals[0]); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index f68f8a3e90a971b5e4a024feaf26ba498afc48da..28b9f8df9c873ee394eb6a241dd9ac06ba6c8796 100644 --- 
a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -392,26 +392,26 @@ Status ProcessInputs( EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { input_tensors->reserve(ninputs); for (int i = 0; i < ninputs; ++i) { - const Node& node = inputs[i].oper->node; + Node* node = &inputs[i].oper->node; int idx = inputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - input_tensors->emplace_back(&node, idx); + input_tensors->emplace_back(node, idx); - const auto& iter = input_nodes->find(&node); + const auto& iter = input_nodes->find(node); if (iter == input_nodes->end()) { - input_nodes->insert({&node, {idx}}); + input_nodes->insert({node, {idx}}); } else { auto& indices = iter->second; if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { - return InvalidArgument("TF_Output ", node.name(), ":", idx, + return InvalidArgument("TF_Output ", node->name(), ":", idx, " appears more than once in the input list"); } indices.push_back(idx); @@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { output_tensors->reserve(noutputs); for (int i = 0; i < noutputs; ++i) { - const Node& node = outputs[i].oper->node; + Node* node = &outputs[i].oper->node; int idx = outputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing output ", i, " from function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, 
idx), "Encountered while creating function '", fn_name, "'"); - output_tensors->emplace_back(&node, idx); + output_tensors->emplace_back(node, idx); } return Status::OK(); } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 95652a11378d6276b5ba6540a07baa15aa77cc1c..5ba26d3c585350aa510f9970cbfc246a9a108543 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -25,6 +25,7 @@ limitations under the License. #include #ifndef __ANDROID__ +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -179,6 +180,15 @@ struct TF_ApiDefMap { tensorflow::mutex lock; }; +#ifndef __ANDROID__ +struct TF_Server { + TF_Server(std::unique_ptr server); + + const tensorflow::string target; + std::unique_ptr server; +}; +#endif + namespace tensorflow { class TensorCApi { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index b0dc0363fdb266a7bb8babcd41ac469b5e763551..d5934a10395ae094f65d3bc8b6cd7b94dbd32410 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - // Load the library. - TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - TF_Code code = TF_GetCode(status); - string status_msg(TF_Message(status)); - TF_DeleteStatus(status); - ASSERT_EQ(TF_OK, code) << status_msg; +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + // Load the library. + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op1.so", status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + // Test op list. 
+ TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi1", op_list.op(0).name()); + TF_DeleteLibraryHandle(lib); + } +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) { TF_Buffer* op_list_buffer = TF_GetAllOpList(); tensorflow::OpList op_list; @@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) { EXPECT_TRUE(found); TF_DeleteBuffer(op_list_buffer); } - -#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) - { - // Test op list. - TF_Buffer op_list_buf = TF_GetOpList(lib); - tensorflow::OpList op_list; - EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); - ASSERT_EQ(op_list.op_size(), 1); - EXPECT_EQ("TestCApi", op_list.op(0).name()); - } -#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) - - TF_DeleteLibraryHandle(lib); } void TestEncodeDecode(int line, const std::vector& data) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3ee31a6a7ac641bbd3fc4c05568b61e433a1d523..ba3d8533db7623b8fa7fdf35093abcd1450776b1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -69,7 +69,7 @@ tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], visibility = [ - "//learning/deepmind/courier:__pkg__", + "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3554ec0bf3202b54bfc38d67e51b89df19832302..192044915f06e3644aebb200a229cce5f220752b 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/platform/host_info.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -404,8 +405,7 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { "The passed in handle is a nullptr"); return nullptr; } - tensorflow::Device* d = nullptr; - status->status = h->handle->OpDevice(&d); + tensorflow::Device* d = h->handle->op_device(); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } @@ -459,13 +459,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { const char* name = op_or_function_name; // Shorthand const tensorflow::AttrTypeMap* types; - status->status = tensorflow::AttrTypeMapForOp(name, &types); - if (status->status.ok()) return new TFE_Op(ctx, name, types); - if (TF_GetCode(status) == TF_NOT_FOUND) { - if (ctx->context.FindFunctionByName(name)) { - status->status = tensorflow::Status::OK(); - return new TFE_Op(ctx, name, nullptr); + bool is_function = false; + status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); + if (status->status.ok()) { + if (is_function && !ctx->context.FindFunctionByName(name)) { + status->status = tensorflow::errors::NotFound( + "'", name, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + tensorflow::port::Hostname(), + ". 
Make sure the operation or function is " + "registered in the binary running in this process."); + return nullptr; } + return new TFE_Op(ctx, name, is_function, types); } return nullptr; } @@ -498,12 +505,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret; - if (op->operation.is_function()) { - status->status = tensorflow::errors::Unimplemented( - "TODO(apassos): Support for attributes for TensorFlow functions is not " - "ready yet."); - return TF_ATTR_INT; // The compiler requires that we return something. - } status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), attr_name, &ret, is_list); return ret; diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 5006b76f1981d068e99a2c081115ebb3a66d8c7f..52b0824552855860dfb138f3ac9a5d3afa7dc965 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } - tensorflow::Device* device; - status->status = handle->handle->Device(&device); - if (!status->status.ok()) { - return nullptr; - } - #ifdef TENSORFLOW_EAGER_USE_XLA + tensorflow::Device* device = handle->handle->device(); + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. 
tensorflow::XlaDevice* xla_device = dynamic_cast(device); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 104d52430cf7aa14d4d2a335a1b96e667f21ce87..67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -79,10 +79,6 @@ struct TFE_TensorHandle { tensorflow::Device* op_device) : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} - TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - tensorflow::EagerContext* ctx) - : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} - TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; @@ -97,10 +93,9 @@ struct TFE_TensorDebugInfo { }; struct TFE_Op { - // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a - // primitive operation. - TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : operation(&ctx->context, op, t) {} + TFE_Op(TFE_Context* ctx, const char* op, bool is_function, + const tensorflow::AttrTypeMap* t) + : operation(&ctx->context, op, is_function, t) {} tensorflow::EagerOperation operation; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 55331022b9dbd0696928fa44430f340f371432ac..0045bb5622647974a3c9f2cdf35bc21e126b4f52 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -589,9 +589,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); const int num_devices = TF_DeviceListCount(devices); + bool has_gpu0 = false; + bool has_gpu1 = false; + for (int i = 0; i < num_devices; ++i) { + const char* dev = TF_DeviceListName(devices, i, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + string 
device_name(dev); + if (device_name.find("GPU:0") != string::npos) { + has_gpu0 = true; + } + if (device_name.find("GPU:1") != string::npos) { + has_gpu1 = true; + } + } const char* kCPUDevice = "CPU:0"; - if (num_devices < 3) { + if (!has_gpu0 || !has_gpu1) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 5ba55a203ff70cc64c07e96b5a869a1f11c9334e..5c11f51e8749de84547ae873f5f55ebd42bc4b3d 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -141,8 +141,9 @@ class GradientTape { // null. The result is populated with one tensor per target element. Status ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result); @@ -396,6 +397,7 @@ template Status InitialGradients( const VSpace& vspace, gtl::ArraySlice target_tensor_ids, + gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, gtl::FlatMap>* result) { @@ -425,8 +427,13 @@ Status InitialGradients( "none of operations outputs match expected tensor"); } } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + // This target tensor was not generated by any operation recorded on + // the tape, so no gradient needs to be computed from it unless this + // target is also a source. 
+ auto source_tensor = sources_that_are_targets.find(id); + if (source_tensor != sources_that_are_targets.end()) { + (*result)[id].push_back(vspace.Ones(source_tensor->second)); + } } } else { (*result)[id].push_back(output_gradients[i]); @@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024; template Status GradientTape::ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_ids, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), @@ -478,7 +486,8 @@ Status GradientTape::ComputeGradient( std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; - Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + Status s = InitialGradients(vspace, target_tensor_ids, + sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca69345264607ac689fb556b4f5c9bc08ea5eb88 --- /dev/null +++ b/tensorflow/c/kernels.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/c/kernels.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +// This file forms the basis of a stable ABI for third-party kernel +// implementations. It is crucial that changes to this file are made cautiously +// and with a focus on maintaining both source and binary compatibility. + +struct TF_KernelBuilder { + ::tensorflow::KernelDefBuilder* cc_builder; + + void* (*create_function)(TF_OpKernelConstruction*); + void (*compute_function)(void*, TF_OpKernelContext*); + void (*delete_function)(void*); +}; + +TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) { + TF_KernelBuilder* result = new TF_KernelBuilder; + result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name); + result->cc_builder->Device(device_name); + result->create_function = create_func; + result->compute_function = compute_func; + result->delete_function = delete_func; + return result; +} + +void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { + DCHECK_NE(builder, nullptr); + delete builder->cc_builder; + delete builder; +} + +namespace tensorflow { +namespace { + +// An OpKernel whose methods delegate to C function pointers. 
+class COpKernel : public OpKernel { + public: + explicit COpKernel(OpKernelConstruction* ctx, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) + : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) { + if (create_func != nullptr) { + c_kernel_ = + (*create_func)(reinterpret_cast(ctx)); + } else { + c_kernel_ = nullptr; + } + } + + void Compute(OpKernelContext* ctx) override { + (*compute_func_)(c_kernel_, reinterpret_cast(ctx)); + } + + ~COpKernel() override { + if (delete_func_ != nullptr) { + (*delete_func_)(c_kernel_); + } + } + + private: + void (*compute_func_)(void*, TF_OpKernelContext* context); + void (*delete_func_)(void*); + void* c_kernel_; +}; + +// A KernelFactory that returns COpKernel instances. +class KernelBuilderFactory + : public ::tensorflow::kernel_factory::OpKernelFactory { + public: + explicit KernelBuilderFactory(TF_KernelBuilder* builder) + : builder_(builder) {} + ::tensorflow::OpKernel* Create( + ::tensorflow::OpKernelConstruction* context) override { + return new ::tensorflow::COpKernel(context, builder_->create_function, + builder_->compute_function, + builder_->delete_function); + } + ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); } + + private: + TF_KernelBuilder* builder_; +}; +} // namespace +} // namespace tensorflow + +void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, + TF_Status* status) { + using tensorflow::register_kernel::Name; + + tensorflow::kernel_factory::OpKernelRegistrar( + builder->cc_builder->Build(), name, + absl::make_unique(builder)); + + TF_SetStatus(status, TF_OK, ""); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..2518789a3c141755d0b3373d53642c487331f68b --- /dev/null +++ b/tensorflow/c/kernels.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_KERNELS_H_ +#define TENSORFLOW_C_KERNELS_H_ + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// C API for TensorFlow Kernels. +// +// This API allows developers to register custom kernel implementations for +// TensorFlow. +// +// See c_api.h header comments for a discussion about API conventions. +// +// Users wishing to extend TensorFlow with new kernels will call +// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with +// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided +// kernels when necessary. + +struct TF_KernelBuilder; +struct TF_OpKernelConstruction; +struct TF_OpKernelContext; + +// Allocates a new kernel builder and returns a pointer to it. +// +// If non-null, TensorFlow will call create_func when it needs to instantiate +// the kernel. The pointer returned by create_func will be passed to +// compute_func and delete_func, thereby functioning as a "this" pointer for +// referring to kernel instances. +// +// The TF_OpKernelConstruction pointer passed to create_func is owned by +// TensorFlow and will be deleted once create_func returns. It must not be used +// after this. 
+// +// When TensorFlow needs to perform a computation with this kernel, it will +// call compute_func. This function will receive the pointer returned by +// create_func (or null if no create_func was provided), along with the inputs +// to the computation. +// +// The TF_OpKernelContext pointer received by compute_func is owned by +// TensorFlow and will be deleted once compute_func returns. It must not be used +// after this. +// +// Finally, when TensorFlow no longer needs the kernel, it will call +// delete_func if one is provided. This function will receive the pointer +// returned in `create_func` or nullptr if no `create_func` was provided. +// +// The caller should pass the result of this function to +// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for +// some reason, the kernel builder will not be registered, the caller should +// delete it with TF_DeleteKernelBuilder. +TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)); + +// Register the given kernel builder with the TensorFlow runtime. If +// registration fails, the given status will be populated. +// +// This call takes ownership of the `builder` pointer. +TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, + TF_KernelBuilder* builder, + TF_Status* status); + +// Deletes the given TF_KernelBuilder. This should be called only if the kernel +// builder is not registered with TensorFlow via TF_RegisterKernelBuilder. 
+TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_KERNELS_H_ diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e706c7c1d96ee1781d8efc0f28c5e0cbcbc80861 --- /dev/null +++ b/tensorflow/c/kernels_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/kernels.h" + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb_text.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +struct MyCustomKernel { + bool created; + bool compute_called; +}; + +static bool delete_called = false; + +static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { + LOG(INFO) << "Wow, actually got into creation"; + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + return s; +} + +static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { + struct MyCustomKernel* s = static_cast(kernel); + s->compute_called = true; +} + +static void MyDeleteFunc(void* kernel) { + struct MyCustomKernel* s = static_cast(kernel); + EXPECT_TRUE(s->created); + EXPECT_TRUE(s->compute_called); + delete_called = true; + delete s; +} + +// Tests registration of a single C kernel and checks that calls through the +// C/C++ boundary are being made. 
+TEST(TestKernel, TestRegisterKernelBuilder) { + const char* kernel_name = "SomeKernelName"; + const char* op_name = "FooOp"; + const char* device_name = "barDev"; + + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp("FooOp", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + ::tensorflow::KernelList list; + list.ParseFromArray(buf->data, buf->length); + ASSERT_EQ(1, list.kernel_size()); + ASSERT_EQ("barDev", list.kernel(0).device_type()); + TF_DeleteBuffer(buf); + TF_DeleteStatus(status); + } + + REGISTER_OP("FooOp") + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + { + ::tensorflow::NodeDef def; + def.set_op("FooOp"); + def.set_device("bar"); + def.add_input("input1"); + def.add_input("input2"); + ::tensorflow::Status status; + std::unique_ptr<::tensorflow::OpKernel> kernel = + ::tensorflow::CreateOpKernel(::tensorflow::DeviceType("barDev"), + nullptr, nullptr, def, 1, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} diff --git a/tensorflow/c/test_op1.cc b/tensorflow/c/test_op1.cc new file mode 100644 index 0000000000000000000000000000000000000000..b22cc9aef2b344282f45340ff12ee849935a26f9 --- /dev/null +++ b/tensorflow/c/test_op1.cc @@ -0,0 +1,23 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc"); + +} // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index c18b07603ae3841d3581741ab5a43f2e8b628356..a09becc49b10d2c58f98fbcc11df5190f794c1d4 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -170,6 +170,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -488,6 +489,7 @@ tf_gen_op_wrappers_cc( "image_ops", "io_ops", "linalg_ops", + "list_ops", "logging_ops", "lookup_ops", "manip_ops", @@ -516,6 +518,8 @@ tf_gen_op_wrappers_cc( ":array_ops", ":const_op", ":math_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", ], ) diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e..ec116f68cf4b61c9b2d15065916ad9169017b659 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -193,6 +193,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, std::vector* asset_file_defs) { + // With SavedModel v2, we write asset file def into metagraph instead of + // collection, so read from metagraph first. 
+ if (meta_graph_def.asset_file_def_size() > 0) { + for (const auto& asset : meta_graph_def.asset_file_def()) { + asset_file_defs->push_back(asset); + } + return Status::OK(); + } + // Fall back to read from collection to be backward compatible with v1. const auto& collection_def_map = meta_graph_def.collection_def(); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); if (assets_it == collection_def_map.end()) { diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 6c29f09cde7ee17c11cb44ce48d8e9128daae4d0..16151e77737429f4fbf690fc34b12a70bacebdc4 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -93,7 +93,7 @@ cc_library( ":tfcompile_lib", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index b95b063348c5cdfdcaed635ba527e9f0bfd6092d..d548de8c44285f6d21dd778db464a31e1b19645b 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) { return errors::InvalidArgument("Must specify --cpp_class"); } codegen_opts.gen_hlo_profile_printer_data = - xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + xla::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); @@ -132,7 +132,7 @@ int main(int argc, char** argv) { std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0c41e095c7bc73e517d4c11c590a21439db1e3da..682c0f0cb05c8c83acac28c8f3abf4f5e355e7c0 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -21,7 +21,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") @@ -52,6 +51,7 @@ cc_library( deps = [ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -65,6 
+65,7 @@ cc_library( ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), alwayslink = 1, @@ -75,10 +76,10 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":flags", ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -190,6 +191,7 @@ cc_library( "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:stack", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", @@ -208,6 +210,18 @@ cc_library( # Internal targets below this point. 
+cc_library( + name = "flags", + srcs = ["flags.cc"], + hdrs = ["flags.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/xla:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "common", srcs = [ @@ -241,6 +255,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) @@ -253,6 +268,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:core_cpu", @@ -263,6 +279,22 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "xla_compilation_cache_test", + srcs = [ + "xla_compilation_cache_test.cc", + ], + deps = [ + ":xla_compilation_cache", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -468,6 +500,7 @@ cc_library( deps = [ ":common", ":encapsulate_util", + ":flags", ":shape_inference_helpers", ":union_find", ":xla_cluster_util", @@ -475,8 +508,6 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", @@ -500,6 +531,7 @@ cc_library( 
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -524,25 +556,6 @@ cc_library( hdrs = ["union_find.h"], ) -cc_library( - name = "producer_consumer_queue", - hdrs = ["producer_consumer_queue.h"], - deps = ["//tensorflow/core:lib"], -) - -tf_cc_test( - name = "producer_consumer_queue_test", - size = "small", - srcs = ["producer_consumer_queue_test.cc"], - deps = [ - ":producer_consumer_queue", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -606,6 +619,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -648,31 +662,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "xla_launch_util_test", - size = "small", - srcs = ["xla_launch_util_test.cc"], - deps = [ - ":common", - ":xla_compilation_cache", - ":xla_launch_util", - ":xla_tensor", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core/kernels:variable_ops", - ], -) - cc_library( name = "xla_fusion_optimizer", srcs = ["xla_fusion_optimizer.cc"], diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 
054f31ba3352b2215e6b0448c8ec8a70cb98b8e5..9f4042630edaec1b9519b6434d859a48372e8b15 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) { return errors::Internal("Could not find compilation device ", device_type.type()); } - *result = registration->requires_compilation; + *result = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; return Status::OK(); } @@ -319,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { return IsXlaCompiledKernel(*n); }); - bool lazy_compilation_enabled = enable_lazy_compilation_ - ? *enable_lazy_compilation_ - : legacy_flags::GetBuildXlaOpsPassFlags() - .tf_xla_enable_lazy_compilation; + bool lazy_compilation_enabled = + enable_lazy_compilation_ + ? 
*enable_lazy_compilation_ + : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation; for (Node* n : xla_compiled_kernels) { TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 617e31488c7daeb714c0ff7056b786e4eaf7873f..8a73101c184e6190921fd7729742922bd96f4bcf 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output loop_cond = ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), latch.output_true, increment_by); Output next_iteration = @@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue( value, frame_name); ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output next_iteration = ops::NextIteration( root.WithOpName(prefix + "/next_iteration"), latch.output_true); CHECK(root.graph() diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 28ec37b1b9c8a1a306b5e778bac5b6ba01c2c997..bcc3213285bee2a2094bd6c39b37ba95874d90ed 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -86,7 +86,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, continue; } else if (src_xla_computation 
&& !dst_xla_computation) { if (src_outside_compilation) { - // Case 1d: outside compilation to host computation control edge. + // Case 1c: outside compilation to host computation control edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( @@ -94,7 +94,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, } } else if (!src_xla_computation && dst_xla_computation) { if (dst_outside_compilation) { - // Case 1d: host computation control to outside compilation edge. + // Case 1c: host computation control to outside compilation edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( @@ -103,40 +103,24 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, } else { // src_xla_computation && dst_xla_computation if (*src_xla_computation != *dst_xla_computation) { if (src_outside_compilation && dst_outside_compilation) { - // Case 1c: outside compilation to outside compilation control edge. + // Case 1b: outside compilation to outside compilation control edge. edges_to_remove.push_back(e); TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1b: outside compilation to another XLA computaition control + // Case 1a: outside compilation to another XLA computaition control // edge. TF_RETURN_IF_ERROR(AppendToListAttr( e->src(), kXlaConnectedToOtherXlaComputationAttrName, *dst_xla_computation)); } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1b: another XLA computaition to outside compilation control + // Case 1a: another XLA computaition to outside compilation control // edge. 
TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, *src_xla_computation)); } - } else { // *src_xla_computation == *dst_xla_computation - if (src_outside_compilation && dst_outside_compilation) { - if (*src_outside_compilation != *dst_outside_compilation) { - // Case 1c: outside compilation to outside compilation control edge. - edges_to_remove.push_back(e); - - TF_RETURN_IF_ERROR(AppendToListAttr( - e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); - } - } else if (src_outside_compilation && !dst_outside_compilation) { - // Case 1a: outside compilation to its XLA computation control edge. - ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); - } else if (!src_outside_compilation && dst_outside_compilation) { - // Case 1a: XLA computation to outside compilation in it control edge. - ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); - } } } } @@ -181,12 +165,6 @@ Status ProcessXlaToXlaDataEdges(Graph* g, edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); } - } else { // *src_xla_computation == *dst_xla_computation - if (src_outside_compilation && dst_outside_compilation && - *src_outside_compilation != *dst_outside_compilation) { - edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); - VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); - } } } @@ -594,14 +572,242 @@ Status AddControlDependencies( return Status::OK(); } +// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges to remove. We should not remove the edge while iterating. 
+ std::vector edges_to_remove; + for (const Edge* e : g->edges()) { + if (!e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation) { + if (*src_outside_compilation != *dst_outside_compilation) { + // Case 1a: outside compilation to outside compilation control edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName, + e->src()->name())); + } + } else if (src_outside_compilation && !dst_outside_compilation) { + // Case 1b: outside compilation to its XLA computation control edge. + ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); + } else if (!src_outside_compilation && dst_outside_compilation) { + // Case 1b: XLA computation to outside compilation in it control edge. + ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); + } + } + + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + return Status::OK(); +} + +// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges between outside compilation and host computation. Notice that + // we do not store `Edge*` directly because we remove some nodes while adding + // Identity nodes, and those Edge pointers might be invalidated. 
+ struct EdgeInfo { + int dst_input, dst_node_id; + }; + std::vector edges; + for (const Edge* e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation && + *src_outside_compilation != *dst_outside_compilation) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); + VLOG(4) << "Oc -> oc edge: " << e->DebugString(); + } + } + + // Remove the edge from host to outside compilation. Add a placeholder as + // outside compilation node input. + std::map placeholders; + for (int i = 0; i < edges.size(); i++) { + Node* dst = g->FindNodeId(edges[i].dst_node_id); + const Edge* e; + TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); + Node* src = e->src(); + int src_output = e->src_output(), dst_input = e->dst_input(); + g->RemoveEdge(e); + + // Find or create placeholder node. 
+ string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder"); + auto iter = placeholders.find(new_name); + Node* placeholder_node; + if (iter == placeholders.end()) { + NodeDefBuilder placeholder_builder(new_name, "Placeholder"); + placeholder_builder.Attr("dtype", src->output_type(src_output)); + string outside_compilation_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), + outside_compilation_attr_name, + &outside_compilation_attr)); + placeholder_builder.Attr(outside_compilation_attr_name, + outside_compilation_attr); + placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName, + src->name()); + placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName, + src_output); + NodeDef placeholder_def; + TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); + Status s; + placeholder_node = g->AddNode(placeholder_def, &s); + TF_RETURN_IF_ERROR(s); + placeholders[new_name] = placeholder_node; + } else { + placeholder_node = iter->second; + } + g->AddEdge(placeholder_node, 0, dst, dst_input); + + // Replace `e->dst()` because its input node changed. + NodeDef new_def = dst->def(); + *new_def.mutable_input(dst_input) = placeholder_node->name(); + TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); + + // Other edge in `edges` might have `e->dst()` as src or dst + // node. Before removing `e->dst()`, replace those edges with + // corresponding edges for `dst_replace_node`. + for (int j = i + 1; j < edges.size(); j++) { + if (edges[j].dst_node_id == edges[i].dst_node_id) { + edges[j].dst_node_id = dst_replace_node->id(); + } + } + } + return Status::OK(); +} + +// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. +Status PostprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather all outside compilation to outside compilation nodes. 
+ std::vector placeholder_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "Placeholder" && + HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) { + placeholder_nodes.push_back(n); + } + } + + // Remove the placeholder nodes, and reconnect original edge. + auto node_name_index = g->BuildNodeNameIndex(); + for (auto n : placeholder_nodes) { + string node_name; + int node_src_output; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name)); + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output)); + auto iter = node_name_index.find(node_name); + if (iter == node_name_index.end()) { + return errors::Internal( + "Cannot find original node for oc -> host placeholder node ", + node_name); + } + + // Change all usage node to use the original node instead. + Node* original_node = iter->second; + std::vector control_edges; + std::vector data_edges; + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) { + control_edges.push_back(e); + } else { + data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); + } + } + for (const Edge* e : control_edges) { + g->AddControlEdge(original_node, e->dst()); + g->RemoveEdge(e); + } + for (int i = 0; i < data_edges.size(); i++) { + Node* dst = data_edges[i].dst; + NodeDef new_def = dst->def(); + int dst_input = data_edges[i].dst_input; + *new_def.mutable_input(dst_input) = + absl::StrCat(original_node->name(), ":", node_src_output); + TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); + + const Edge* edge_to_replace = nullptr; + TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); + g->RemoveEdge(edge_to_replace); + g->AddEdge(original_node, node_src_output, replace_node, dst_input); + + // Other edges might have `dst` as dst node. Update those edges with + // `replace_node`. 
+ for (int j = i + 1; j < data_edges.size(); j++) { + if (data_edges[j].dst == dst) { + data_edges[j].dst = replace_node; + } + } + + // Other placeholder node might have `dst` as original node. Update + // `node_name_index` with `replace_node`. + node_name_index[replace_node->name()] = replace_node; + } + + // Remove placeholder node. + g->RemoveNode(n); + } + return Status::OK(); +} + +// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. +Status PostprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + auto node_name_index = g->BuildNodeNameIndex(); + + // Reconnect outside compilation to outside compilation control edge. + for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = + GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, + &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName); + for (const string& control_input : control_deps) { + auto iter = node_name_index.find(control_input); + if (iter == node_name_index.end()) { + return errors::Internal("Cannot find original node for ", + control_input); + } + g->AddControlEdge(iter->second, n); + } + } + } + return Status::OK(); +} } // namespace const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; -const char kXlaConnectedToXlaComputationAttrName[] = - "_xla_connected_to_xla_computation"; -const char kXlaConnectedFromXlaComputationAttrName[] = - "_xla_connected_from_xla_computation"; const char kXlaConnectedToOtherXlaComputationAttrName[] = "_xla_connected_to_other_xla_computation"; const char kXlaConnectedFromOtherXlaComputationAttrName[] = @@ -616,6 +822,15 @@ const char kHostToOutsideCompilationOriginalNodeAttrName[] = "_xla_host_to_oc_node_name"; const char 
kHostToOutsideCompilationSrcOutputAttrName[] = "_xla_host_to_oc_src_output"; +const char kXlaConnectedToXlaComputationAttrName[] = + "_xla_connected_to_xla_computation"; +const char kXlaConnectedFromXlaComputationAttrName[] = + "_xla_connected_from_xla_computation"; +const char kOutsideCompilationOriginalNodeAttrName[] = + "_xla_oc_to_oc_node_name"; +const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; +const char kXlaControlDependenciesWithinXlaClusterAttrName[] = + "_xla_control_dependencies_within_xla_cluster"; Status PerformStaticShapeInferenceBeforeEncapsulation( Graph* g, const string& xla_computation_attr_name, @@ -699,4 +914,39 @@ Status PostprocessForEncapsulation( return Status::OK(); } +Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Remove edges from source node to outside compilation nodes, and edges + // from outside compilation nodes to sink node. + std::vector edges_to_remove; + for (const Edge* e : g->source_node()->out_edges()) { + if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (const Edge* e : g->sink_node()->in_edges()) { + if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + + TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + } // 
namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index a3b193eea745d4e44781225130216253c19371da..e363bc5754ac395bae262dc67a780a0173efaf5e 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -44,14 +44,6 @@ Status PerformStaticShapeInferenceBeforeEncapsulation( Graph* g, const string& xla_computation_attr_name, const string& outside_compilation_attr_name); -// Attribute indicating that some ops in this node's XLA computation has control -// dependency on this node. Attribute value will always be "true". -extern const char kXlaConnectedToXlaComputationAttrName[]; - -// Attribute indicating that this node has control dependency on some ops in -// this node's XLA computation. Attribute value will always be "true". -extern const char kXlaConnectedFromXlaComputationAttrName[]; - // Attribute indicating that some ops in other XLA computation has control // dependency on this node. Attribute value will be a list of string (XLA // computation names). @@ -81,6 +73,14 @@ extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; // int (src_output for original edge). extern const char kOutsideCompilationToHostSrcOutputAttrName[]; +// Attribute indicating that some ops in this node's XLA computation has control +// dependency on this node. Attribute value will always be "true". +extern const char kXlaConnectedToXlaComputationAttrName[]; + +// Attribute indicating that this node has control dependency on some ops in +// this node's XLA computation. Attribute value will always be "true". +extern const char kXlaConnectedFromXlaComputationAttrName[]; + // Attribute indicating that this is an Placeholder node added to act as a // temporary input node for an host node. Attribute value will be string // (original input node name). @@ -91,19 +91,31 @@ extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; // for original edge). 
extern const char kHostToOutsideCompilationSrcOutputAttrName[]; -// Preprocesses the graph for encapsulation. It will perform the following -// operations in order: +// Attribute indicating that this is an Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// string (original input node name). +extern const char kOutsideCompilationOriginalNodeAttrName[]; + +// Attribute indicating that this is an Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// int (src_output for original edge). +extern const char kOutsideCompilationSrcOutputAttrName[]; + +// Attribute indicating that this node has control dependencies on some other +// nodes within the same XLA cluster. Attribute value will be a list of string +// (node names). +extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; + +// Preprocesses edges between different XLA clusters for encapsulation. It will +// perform the following operations in order: // -// 1a. For control edges between outside compilation and its XLA computation, -// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the -// outside compilation node. -// 1b. For control edges between outside compilation and another XLA +// 1a. For control edges between outside compilation and another XLA // computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName // = XLA computation node name" to the outside compilation node. -// 1c. For control edges between different outside compilations, remove the edge -// and add attr "kXlaControlDependenciesAttrName = src node name" to dst -// node. -// 1d. For control edges between outside compilation and host computation, +// 1b. For control edges between different outside compilations (in different +// XLA computations), remove the edge and add attr +// "kXlaControlDependenciesAttrName = src node name" to dst node. +// 1c. 
For control edges between outside compilation and host computation, // remove the edge and add attr "kXlaControlDependenciesAttrName = src node // name" to dst node. // 2. For data edges between different XLA computations, if either src or dst @@ -117,6 +129,25 @@ Status PreprocessForEncapsulation(Graph* g, // Information for XLA computation. struct XlaClusterInfo { + // Add an explicitly-defined default constructor for this class. + // + // The compiler may delete the default constructor here because + // host_compute_core is a const member whose type (std::map) doesn't + // necessarily have a user provided constructor -- while libc++ and + // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at + // least >= 7.3 does not. See also c++11 [class.ctor] p5. + // + // TODO(klimek): In c++17 we'll be able to initialize host_compute_core + // without losing aggregate initialization, which allows us to get rid of + // the constructor definitions again. + XlaClusterInfo() {} + XlaClusterInfo(const string& cluster_name, + const NameAttrList& func_name_attrs, Node* node, + const std::map& host_compute_core) + : cluster_name(cluster_name), + func_name_attrs(func_name_attrs), + node(node), + host_compute_core(host_compute_core) {} // XLA cluster name. It might be different from `func_name`. const string cluster_name; // Name and attributes of XLA computation function. @@ -127,26 +158,53 @@ struct XlaClusterInfo { const std::map host_compute_core; }; -// Postprocesses the graph for encapsulation. This function reverts what -// `PreprocessForEncapsulation` did. It will perform the following operations in -// order: +// Postprocesses edges between different XLA clusters for encapsulation. This +// function reverts what `PreprocessForEncapsulation` did. It will perform the +// following operations in order: // // 1. Remove Placeholder nodes between outside compilation and host computation // (created in `PreprocessForEncapsulation` step 3). // 2. 
Remove Identity nodes created in `PreprocessForEncapsulation` step 2. -// 3a. Reconnect control edges between different outside compilations (marked by -// `PreprocessForEncapsulation` step 1c) and control edges between outside -// compilation and host computation (marked by `PreprocessForEncapsulation` -// step 1d). -// 3b. Reconnect control edges between outside compilation and another XLA -// computation (marked by `PreprocessForEncapsulation` step 1b). -// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are -// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`. +// 3a. Reconnect control edges between outside compilation and another XLA +// computation (marked by `PreprocessForEncapsulation` step 1a). +// 3b. Reconnect control edges between different outside compilations (marked by +// `PreprocessForEncapsulation` step 1b). +// 3c. Reconnect control edges between outside compilation and host computation +// (marked by `PreprocessForEncapsulation` step 1c). Status PostprocessForEncapsulation( Graph* g, const string& xla_computation_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters); +// Preprocesses edges within the same XLA cluster. It will perform the following +// operations in order: +// +// 0. Remove edges from source node to outside compilation nodes, and edges +// from outside compilation nodes to sink node. +// 1a. For edges between different outside compilation clusters, remove the edge +// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node +// name" to dst node. +// 1b. For control edges between outside compilation and its XLA computation, +// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the +// outside compilation node. +// 2. For data edges between different outside compilations, remove the edge +// and create a Placeholder node as dst node's input. 
+Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); + +// Postprocesses edges within the same XLA cluster. This function reverts what +// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the +// following operations in order: +// +// 1. Remove Placeholder nodes between different outside compilations (created +// in `PreprocessEdgesBetweenOutsideCompilations` step 2). +// 2a. Reconnect control edges between different outside compilations (marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1a). +// Notice that control edges marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. +// They are handled in `RewriteOutsideCompilationSubgraphFn`. +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc index 7255df3112916b7abcc98ff8204efc8c02209b13..25c32cef01d7f9877a35001457539f2ad189192f 100644 --- a/tensorflow/compiler/jit/encapsulate_util_test.cc +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -107,28 +107,19 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { identity4_node->AddAttr("_xla", "1"); identity4_node->AddAttr("_oc", "0"); identity5_node->AddAttr("_xla", "1"); - // Case 1a: control edges between outside compilation and its XLA computation. - g.AddControlEdge(add_node, identity0_node); - g.AddControlEdge(identity0_node, identity1_node); - // Case 1b: control edges between outside compilation and another XLA + // Case 1a: control edges between outside compilation and another XLA // computation. g.AddControlEdge(identity0_node, identity3_node); g.AddControlEdge(identity1_node, identity4_node); - // Case 1c: control edges between different outside compilations. 
+ // Case 1b: control edges between different outside compilations. g.AddControlEdge(identity0_node, identity4_node); - // Case 1d: control edges between outside compilation and host computation. + // Case 1c: control edges between outside compilation and host computation. g.AddControlEdge(const0_node, identity0_node); g.AddControlEdge(identity0_node, identity2_node); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); - // Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the - // outside compilation node. - EXPECT_TRUE(HasNodeAttr(identity0_node->def(), - kXlaConnectedFromXlaComputationAttrName)); - EXPECT_TRUE(HasNodeAttr(identity0_node->def(), - kXlaConnectedToXlaComputationAttrName)); - // Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name" + // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" // to the outside compilation node. std::vector attr; TF_CHECK_OK(GetNodeAttr(identity0_node->def(), @@ -140,13 +131,13 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) { kXlaConnectedFromOtherXlaComputationAttrName, &attr)); EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr[0], "0"); - // Case 1c: add attr "_xla_control_deps = src node name" to dst node. + // Case 1b: add attr "_xla_control_deps = src node name" to dst node. attr.clear(); TF_CHECK_OK(GetNodeAttr(identity4_node->def(), kXlaControlDependenciesAttrName, &attr)); EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr[0], "identity0"); - // Case 1d: add attr "_xla_control_deps = src node name" to dst node. + // Case 1c: add attr "_xla_control_deps = src node name" to dst node. 
attr.clear(); TF_CHECK_OK(GetNodeAttr(identity0_node->def(), kXlaControlDependenciesAttrName, &attr)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 70b019d35fc80c975bc23ef42d61e3e36e4d0924..e3c7e2f89be9b37b51a633dabb099969c181013f 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -366,7 +366,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode( // replace this node with compilation result node. // 3) all outside compilation graphs. Status ConstructHostGraph( - const string& xla_cluster_name, + const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { host_graph->reset(new Graph(fld)); @@ -394,12 +394,12 @@ Status ConstructHostGraph( for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; FunctionBody* host_fbody = nullptr; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fld->Find(host_func), AttrSlice(), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &host_fbody)); + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_func), AttrSlice(), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &host_fbody)); std::unique_ptr host_fbody_deleter(host_fbody); // We use ReverseDFS() to copy nodes. 
Make sure all nodes are reverse @@ -411,52 +411,53 @@ Status ConstructHostGraph( node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); Status s; - ReverseDFS(*host_fbody->graph, /*enter=*/nullptr, - [&](const Node* n) { - if (!s.ok()) { - return; - } - - Node* copy; - if (node_map.find(n) != node_map.end()) { - // Already copied this node. - copy = node_map.at(n); - } else if (IsKeyPlaceholderNode(*n)) { - // Change a). - copy = key_placeholder; - node_map[n] = copy; - } else { - // Copy the node. - NodeDef copy_def = n->def(); - // Change c). - copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); - if (!s.ok()) { - return; - } - node_map[n] = copy; - } - - // Only handle input edges. Output edges will be added later as - // its output nodes' input edges. - for (auto e : n->in_edges()) { - if (node_map.find(e->src()) == node_map.end()) { - s = errors::Internal("Cannot find node image for ", - e->src()->DebugString()); - return; - } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); - } - - // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); - } - }, - NodeComparatorID()); + ReverseDFS( + *host_fbody->graph, /*enter=*/nullptr, + [&](const Node* n) { + if (!s.ok()) { + return; + } + + Node* copy; + if (node_map.find(n) != node_map.end()) { + // Already copied this node. + copy = node_map.at(n); + } else if (IsKeyPlaceholderNode(*n)) { + // Change a). + copy = key_placeholder; + node_map[n] = copy; + } else { + // Copy the node. + NodeDef copy_def = n->def(); + // Change c). + copy_def.clear_device(); + copy = (*host_graph)->AddNode(copy_def, &s); + if (!s.ok()) { + return; + } + node_map[n] = copy; + } + + // Only handle input edges. Output edges will be added later as + // its output nodes' input edges. 
+ for (auto e : n->in_edges()) { + if (node_map.find(e->src()) == node_map.end()) { + s = errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + return; + } + (*host_graph) + ->AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); + } + + // Change b). + if (copy->type_string() == "_XlaRecvAtHost" || + copy->type_string() == "_XlaSendFromHost") { + (*host_graph)->AddControlEdge(copy, sequencer); + } + }, + NodeComparatorID()); if (!s.ok()) { return s; } @@ -475,6 +476,10 @@ Status ConstructHostGraph( host_graph->get(), std::unordered_set{(*host_graph)->sink_node()}); + // Postprocess edges between different outside compilations. + TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( + host_graph->get(), outside_compilation_attr_name)); + if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_host_graph_for_", @@ -800,6 +805,11 @@ Status ExtractOutsideCompilationForFunction( }, &fbody)); std::unique_ptr fbody_deleter(fbody); + + // Preprocess edges between different outside compilations. They will be + // restored in `ConstructHostGraph()`. + TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( + fbody->graph, outside_compilation_attr_name)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("extract_outside_compilation_for_func_before_", func_name), @@ -838,7 +848,12 @@ Status ExtractOutsideCompilationForFunction( FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( shape_inference_graph); - TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + if (fld->Find(shape_inference_graph)) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + } } } } @@ -854,8 +869,9 @@ Status ExtractOutsideCompilationForFunction( // Construct host graph. 
if (!outside_compilation_host_graphs.empty()) { - TF_RETURN_IF_ERROR(ConstructHostGraph( - xla_cluster_name, outside_compilation_host_graphs, fld, host_graph)); + TF_RETURN_IF_ERROR( + ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph)); } // Remove the outside compilation graphs from function library. diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index c5bd64f004ef98853955372680277e04c16bdc9e..bff956100da661b679b4557fce53671e6cef88c5 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -290,21 +290,18 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes)); EXPECT_EQ(shapes.size(), 1); EXPECT_EQ(shapes[0].dim_size(), 1); - // Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a - // non-empty value, and "1" should have an empty value. + // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have + // empty values. string shape_inference_graph; TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", &shape_inference_graph)); - EXPECT_EQ(shape_inference_graph, - "_outside_compilation_shape_inference_cluster_0"); + EXPECT_EQ(shape_inference_graph, ""); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", &shape_inference_graph)); EXPECT_EQ(shape_inference_graph, ""); // Check `shape_inference_graphs`. - EXPECT_EQ(shape_inference_graphs.size(), 1); - EXPECT_EQ(shape_inference_graphs[0], - "_outside_compilation_shape_inference_cluster_0"); + EXPECT_EQ(shape_inference_graphs.size(), 0); // Check `host_graph`: verify we have key placeholder and sequencer. 
Node *key_placeholder = nullptr, *sequencer = nullptr; @@ -333,8 +330,8 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) { send_recv_nodes.push_back(n); } } - EXPECT_EQ(num_send_from_host, 2); - EXPECT_EQ(num_recv_at_host, 2); + EXPECT_EQ(num_send_from_host, 1); + EXPECT_EQ(num_recv_at_host, 1); for (Node *n : send_recv_nodes) { Node *input_node; TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node)); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..98e344b3a080aa8aab27cd41564a90427bac151e --- /dev/null +++ b/tensorflow/compiler/jit/flags.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include // NOLINT + +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +BuildXlaOpsPassFlags* build_ops_flags; +DumpGraphFlags* dump_graph_flags; +MarkForCompilationPassFlags* mark_for_compilation_flags; +XlaDeviceFlags* device_flags; +XlaOpsCommonFlags* ops_flags; + +std::vector* flag_list; +std::once_flag flags_init; + +void AppendDumpGraphFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix, + "Path prefix to which graphs dumped during debugging should be " + "written."), + }; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", + &mark_for_compilation_flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. 
Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", + &mark_for_compilation_flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", + &mark_for_compilation_flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", + &mark_for_compilation_flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", + &mark_for_compilation_flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", + &mark_for_compilation_flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AllocateAndParseFlags() { + build_ops_flags = new BuildXlaOpsPassFlags; + build_ops_flags->tf_xla_enable_lazy_compilation = true; + + dump_graph_flags = new DumpGraphFlags; + dump_graph_flags->tf_dump_graph_prefix = "/tmp/"; + + mark_for_compilation_flags = new MarkForCompilationPassFlags; + mark_for_compilation_flags->tf_xla_auto_jit = 0; + mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->tf_xla_max_cluster_size = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_clustering_debug = false; + mark_for_compilation_flags->tf_xla_cpu_global_jit = false; + mark_for_compilation_flags->tf_xla_clustering_fuel = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_fusion_only = false; + + device_flags = new XlaDeviceFlags; + device_flags->tf_xla_compile_on_demand = false; + + ops_flags = new XlaOpsCommonFlags; + ops_flags->tf_xla_always_defer_compilation = false; + + flag_list = new std::vector({ + Flag("tf_xla_enable_lazy_compilation", + 
&build_ops_flags->tf_xla_enable_lazy_compilation, ""), + + Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + + Flag("tf_xla_always_defer_compilation", + &ops_flags->tf_xla_always_defer_compilation, ""), + }); + AppendDumpGraphFlagsInternal(flag_list); + AppendMarkForCompilationPassFlagsInternal(flag_list); + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); +} + +} // namespace + +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *build_ops_flags; +} + +DumpGraphFlags* GetDumpGraphFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return dump_graph_flags; +} + +MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return mark_for_compilation_flags; +} + +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return device_flags; +} + +const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *ops_flags; +} + +void AppendMarkForCompilationPassFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendMarkForCompilationPassFlagsInternal(flag_list); +} + +void AppendDumpGraphFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendDumpGraphFlagsInternal(flag_list); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/flags.h similarity index 56% rename from tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h rename to tensorflow/compiler/jit/flags.h index 2affda6ab4e0fbad32a246744fa5b38aeb629c1b..5ddea588eef5270880d91623dc05893da265960a 100644 --- 
a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -13,10 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. +#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_FLAGS_H_ #include @@ -24,16 +22,9 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags( - std::vector* flag_list); -// The values of flags associated with the XLA bridge's -// mark_for_compilation_pass module. -typedef struct { +// Flags associated with the XLA bridge's mark_for_compilation_pass module. +struct MarkForCompilationPassFlags { int32 tf_xla_auto_jit; // Control compilation of operators into XLA // computations on CPU and GPU devices. 0 = use // ConfigProto setting; -1 = off; 1 = on for things @@ -55,14 +46,58 @@ typedef struct { // is set to ON* and overrides its behavior. If // true, enable fusion of element-wise operations // only using XLA. -} MarkForCompilationPassFlags; +}; + +// Flags associated with the XLA bridge's xla_device module. +struct XlaDeviceFlags { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. 
+ bool tf_xla_compile_on_demand; +}; + +// Flags common to the _Xla* ops and their kernels. +struct XlaOpsCommonFlags { + // If true, _XlaCompile always refuses to compile the cluster, which means the + // XLA clusters always run in the TF executor. Defaults to false. + bool tf_xla_always_defer_compilation; +}; -// Return a pointer to the MarkForCompilationPassFlags struct; +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to true. + bool tf_xla_enable_lazy_compilation; +}; + +// Flags for the XLA bridge's dump_graph module. +struct DumpGraphFlags { + // Path prefix to which graphs dumped during debugging should be written. + string tf_dump_graph_prefix; +}; + +// Return a pointer to the DumpGraphFlags struct; // repeated calls return the same pointer. // This should be called only after Flags::Parse() has returned. + +// Getters for flags structs defined above. The first call to any of these +// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer +// always return the same pointer. MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); +XlaDeviceFlags* GetXlaDeviceFlags(); +const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); +DumpGraphFlags* GetDumpGraphFlags(); + +// Appends the flag definitions associated with +// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. +// +// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. 
+void AppendMarkForCompilationPassFlags( + std::vector* flag_list); +void AppendDumpGraphFlags(std::vector* flag_list); -} // namespace legacy_flags } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index bd8719b7f1acb79e0b0cd91f2f0de0d66d8dab46..ce53f70b79d97ab087fefe542920b33f883632a2 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -18,11 +18,12 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/types/optional.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -34,14 +35,30 @@ limitations under the License. namespace tensorflow { namespace { -Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) { - TF_RET_CHECK(n->type_string() == "Const"); + +// StatusOrOptional instances hold +// +// - A non-OK Status to indicate an error that needs to be propagated out of +// this pass (e.g. the Graph is malformed). +// +// - A nullopt to indicate the function that created the instance failed to do +// what it set out to do but this is not actually an error +// (e.g. TryToGetTensorFromConstOp was passed a non-Const node). +// +// - A T to indicate a successful operation. 
+template +using StatusOrOptional = xla::StatusOr>; + +StatusOrOptional TryToGetTensorFromConstOp(Node* n) { + if (n->type_string() != "Const") { + return {absl::nullopt}; + } + const TensorProto* proto = nullptr; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); Tensor tensor(proto->dtype()); TF_RET_CHECK(tensor.FromProto(*proto)); - *out_tensor = std::move(tensor); - return Status::OK(); + return {tensor}; } struct SliceInputs { @@ -70,7 +87,7 @@ std::vector IntTensorAsVector(const Tensor& t) { // Packages up the inputs to a Slice operation into an instance of // `SliceInputs`. -Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { +StatusOrOptional GetSliceInputs(Node* slice) { const int kSliceInputIndex = 0; const int kSliceBeginIndex = 1; const int kSliceSizeIndex = 2; @@ -81,23 +98,27 @@ Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge)); const Edge* slice_begin_edge; TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge)); - slice_inputs->input = + + SliceInputs slice_inputs; + slice_inputs.input = Output(slice_input_edge->src(), slice_input_edge->src_output()); - slice_inputs->begin = + slice_inputs.begin = Output(slice_begin_edge->src(), slice_begin_edge->src_output()); - slice_inputs->size = + slice_inputs.size = Output(slice_size_edge->src(), slice_size_edge->src_output()); - Tensor tf_slice_size; - TF_RETURN_IF_ERROR( - GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size)); + TF_ASSIGN_OR_RETURN(absl::optional tf_slice_size, + TryToGetTensorFromConstOp(slice_inputs.size.node())); + if (!tf_slice_size.has_value()) { + return {absl::nullopt}; + } - if (tf_slice_size.dims() != 1) { - return errors::Internal("Expected vector for the slice size input."); + if (tf_slice_size->dims() != 1) { + return {absl::nullopt}; } - slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size); - return Status::OK(); + 
slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size); + return {slice_inputs}; } // Casts `x` to a DT_INT64 if it isn't one already. @@ -187,8 +208,12 @@ Status ComputeSliceSize(const Scope& host_scope, DCHECK_EQ(slice_size.back().type(), DT_INT64); } - *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, - ops::Const(host_scope.WithOpName("concat_axis"), 0)); + // Trivial ConcatV2 nodes (with exactly one input) are disallowed. + *size = + slice_size.size() == 1 + ? slice_size[0] + : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + ops::Const(host_scope.WithOpName("concat_axis"), 0)); return Status::OK(); } @@ -221,6 +246,9 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( .WithOpName("static_shaped_slice"), slice_inputs_int64.input, slice_inputs_int64.begin, slice_size) .node(); + + TF_RETURN_IF_ERROR(main_scope.status()); + std::vector compile_time_const_inputs; compile_time_const_inputs.push_back("size"); (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, @@ -263,10 +291,9 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// Returns true if `n` is a slice we can rewrite to have a static shape -// (i.e. have the output shape only depend on the "size" input). Fills in -// `slice_inputs` in the process. -bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) { +// Return true if `n` is a slice we can rewrite to have a static shape +// (i.e. have the output shape only depend on the "size" input). +xla::StatusOr IsRewritableSlice(Node* n) { if (n->type_string() != "Slice") { return false; } @@ -276,8 +303,9 @@ bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) { return false; } - if (!GetSliceInputs(n, slice_inputs).ok()) { - // Could not parse slice inputs. E.g. the sizes input was not a constant. 
+ TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + if (!slice_inputs.has_value()) { return false; } @@ -288,17 +316,20 @@ bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) { } Status FindAndRewriteSlices(Graph* g, bool* changed) { - std::vector> slices_to_rewrite; + std::vector slices_to_rewrite; for (Node* n : g->nodes()) { - SliceInputs slice_inputs; - if (IsRewritableSlice(n, &slice_inputs)) { - slices_to_rewrite.push_back({n, std::move(slice_inputs)}); + TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + if (is_rewritable) { + slices_to_rewrite.push_back(n); } } - for (const auto& pair : slices_to_rewrite) { - TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second, - *GetXlaClusterForNode(*pair.first))); + for (Node* n : slices_to_rewrite) { + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + TF_RET_CHECK(slice_inputs.has_value()); + TF_RETURN_IF_ERROR( + RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n))); } if (!slices_to_rewrite.empty()) { @@ -314,8 +345,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) { Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_clustering_debug) { dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", **options.graph, options.flib_def); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 0f6f612e967035f6af3e4aff2a499d5cedd018af..a2f1b831ad7605237e23c15cc43b337e06265553 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
namespace tensorflow { namespace { +using ::testing::_; using testing::matchers::AssignedDevice; using testing::matchers::Attr; using testing::matchers::Const; @@ -142,6 +143,26 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) { EXPECT_THAT(static_shaped_slice, m_dynamic_slice); } +TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + EXPECT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2"))))); +} + TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { Scope root = Scope::NewRootScope() .ExitOnError() @@ -166,18 +187,18 @@ TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { CtrlDeps(NodeWith(Op("Placeholder"), Name("control"))))); } +int64 ToInt64(int v) { return static_cast(v); } + TEST(SliceToDynamicSliceRewriteTest, Int64Indices) { Scope root = Scope::NewRootScope() .ExitOnError() .WithAssignedDevice(kDeviceName) .WithXlaCluster("cluster_0"); - auto to_int64 = [](int v) { return static_cast(v); }; - Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); Output size = - ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)}); + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)}); Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); std::unique_ptr result; @@ 
-252,13 +273,35 @@ TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) { Attr(kXlaCompileTimeConstantInputsAttr))))); } +TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + Output size = ops::Const(root.WithOpName("size"), {}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(static_shaped_slice, + NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr), + Inputs(_, _, Out(NodeWith(Name(size.node()->name())))))); +} + TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { Scope root = Scope::NewRootScope() .ExitOnError() .WithAssignedDevice(kDeviceName) .WithXlaCluster("cluster_0"); - auto to_int64 = [](int v) { return static_cast(v); }; + auto ToInt64 = [](int v) { return static_cast(v); }; Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); @@ -271,7 +314,7 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder); Output size = - ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}}); + ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}}); TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2)); std::unique_ptr result; @@ -281,5 +324,82 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { Not(Contains(NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr))))); } + 
+TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a); + + Output size_b = ops::Const(root.WithOpName("size_a"), {-1, 200}); + Output slice_with_slice_input = ops::Slice( + root.WithOpName("slice_with_slice_input"), slice, begin, size_b); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_input/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(Out(NodeWith( + Op("Slice"), + Name("slice/static_shaped_slice/static_shaped_slice"))), + _, _))); +} + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input_float = + ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT); + Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64); + + Output begin_begin = + ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32); + Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1}); + Output begin = + ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size); + + Output size = + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)}); + Output slice_with_slice_begin = ops::Slice( + 
root.WithOpName("slice_with_slice_begin"), input_float, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_begin/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(_, + Out(NodeWith( + Op("Slice"), + Name("begin/static_shaped_slice/static_shaped_slice"))), + _))); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 107d521077c3fe2ac72d113d46e2566c78c9fafb..f79bdc1e2e8d82c9144d1bb9923ad36d8541cbdb 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -44,11 +44,8 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); -// TODO(b/111210515): IncreaseDynamismForAutoJitPass creates slices with index -// type DT_INT64 which do not have a kernel on GPU. 
-// -// REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, -// IncreaseDynamismForAutoJitPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + IncreaseDynamismForAutoJitPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 830db9ebdd92608c375ad778eced833e26729325..0583774714c6db7a2fa515fc8a0d304e1898db97 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -12,10 +12,10 @@ cc_library( hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", - "//tensorflow/compiler/jit/legacy_flags:xla_ops_common_flags", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 6bcae1dcc3dcf87faa5317e0064c4c0cf80af465..ad71df5a694a5f8da94675049df1062a7edb6253 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -39,12 +39,22 @@ limitations under the License. 
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + namespace tensorflow { namespace { -Status PlatformInfoFromContext(OpKernelConstruction* ctx, - XlaPlatformInfo* result) { +XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; @@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx, } if (!device_allocator) { - TF_ASSIGN_OR_RETURN(se::Platform* const platform, - se::MultiPlatformManager::PlatformWithId(platform_id)); + xla::StatusOr maybe_platform = + se::MultiPlatformManager::PlatformWithId(platform_id); + OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); + xla_allocator = absl::make_unique( - platform, ctx->device()->GetAllocator({})); + maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); } - *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - std::move(xla_allocator), device_allocator); - - return Status::OK(); + return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); } // A closure describing how to run a compiled version of a TensorFlow function. 
@@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, : OpKernel(ctx), constants_(constants), resources_(resources), - function_(function) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} + function_(function), + platform_info_(PlatformInfoFromContext(ctx)) {} static Status BuildCompilationCache(OpKernelContext* ctx, const XlaPlatformInfo& platform_info, @@ -277,8 +286,10 @@ static Status CompileToLocalExecutable( // rather than a one-element tuple. compile_options.always_return_tuple = false; - return cache->Compile(options, function, constant_args, *variables, ctx, - compile_options, + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_args, *variables, ctx, &args)); + return cache->Compile(options, function, args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, kernel, executable); @@ -333,18 +344,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - // Helper static functions to construct parameters for // XlaLocalLaunchBase constructor from OpKernelConstruction. 
std::vector ConstantsVector(OpKernelConstruction* ctx) { @@ -381,7 +380,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) { return *func; } -#undef OP_REQUIRES_OK_RETURN +bool MustCompileAttr(OpKernelConstruction* ctx) { + bool must_compile; + OP_REQUIRES_OK_RETURN(ctx, false, + ctx->GetAttr("must_compile", &must_compile)); + return must_compile; +} } // namespace XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) @@ -396,10 +400,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx), constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), - function_(FunctionAttr(ctx)) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); -} + function_(FunctionAttr(ctx)), + platform_info_(PlatformInfoFromContext(ctx)), + must_compile_(MustCompileAttr(ctx)) {} void XlaCompileOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaCompileOp " << def().name() @@ -409,13 +412,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) { + bool cannot_compile_cluster; + { + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster = cannot_compile_cluster_; + } + + if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + cannot_compile_cluster) { executable = nullptr; } else { - OP_REQUIRES_OK(ctx, CompileToLocalExecutable( - ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/!must_compile_, &client, - &variables, &kernel, &executable)); + Status status = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, + /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); + if (must_compile_ || status.code() != error::UNIMPLEMENTED) { + OP_REQUIRES_OK(ctx, status); + } + + if (status.code() == error::UNIMPLEMENTED) { + LOG(WARNING) << "Compilation failed:" 
<< status.ToString() + << ". Falling back to TF function call."; + executable = nullptr; + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster_ = true; + } } AllocatorAttributes host_alloc_attrs; @@ -452,9 +472,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(1, compilation_successful); } -XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index ac90837e0d90943b93e2cdb01a30fa0837ba94df..7b4d4b5b4737784d4fe277d5bbe9cab79cfaf4c9 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#include + #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" @@ -33,6 +35,7 @@ namespace tensorflow { class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; explicit XlaPlatformInfo(const DeviceType device_type, se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, @@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel { protected: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; - XlaPlatformInfo platform_info_; + const NameAttrList function_; + const XlaPlatformInfo platform_info_; }; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph @@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel { private: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; + const NameAttrList function_; XlaPlatformInfo platform_info_; - bool must_compile_; + const bool must_compile_; + + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented + // error when compiling the cluster this _XlaCompile is supposed to compile. + // If `cannot_compile_cluster_` is true then we avoid compiling this cluster + // on any future calls to _XlaCompile. 
+ bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false; + + mutex cannot_compile_cluster_mu_; }; class XlaRunOp : public OpKernel { @@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - XlaPlatformInfo platform_info_; + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD deleted file mode 100644 index 49ff9a3ddd1fc14ba59209c39e00856986deab2d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -# Legacy command line flags for the XLA bridge libraries. - -# Please do not add more flags to this package. - -# The XLA bridge libraries were written in an environment that allowed -# command-line flags to be scattered freely throughout the libraries. This -# model, while initially convenient, leads to a proliferation in unused command -# line flags in tests and binaries, and serious problems in servers, where one -# might wish parameters to be different in independent RPC calls to the same -# routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. 
- -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -cc_library( - name = "mark_for_compilation_pass_flags", - srcs = ["mark_for_compilation_pass_flags.cc"], - hdrs = ["mark_for_compilation_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_device_flags", - srcs = ["xla_device_flags.cc"], - hdrs = ["xla_device_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "build_xla_ops_pass_flags", - srcs = ["build_xla_ops_pass_flags.cc"], - hdrs = ["build_xla_ops_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_ops_common_flags", - srcs = ["xla_ops_common_flags.cc"], - hdrs = ["xla_ops_common_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc deleted file mode 100644 index 73f4dc73ed83e2d1e89ccd6c99970d46b5767104..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include // NOLINT - -#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { -namespace { - -BuildXlaOpsPassFlags* flags; -std::vector* flag_list; -std::once_flag flags_init; - -void AllocateAndParseFlags() { - flags = new BuildXlaOpsPassFlags; - flags->tf_xla_enable_lazy_compilation = true; - flag_list = new std::vector({ - Flag("tf_xla_enable_lazy_compilation", - &flags->tf_xla_enable_lazy_compilation, ""), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -} // namespace - -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return *flags; -} -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h deleted file mode 100644 index 9aa5cf64d6db56ae36875ca08d2ae88c73604733..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ - -namespace tensorflow { -namespace legacy_flags { - -// Flags for the build_xla_ops pass. -struct BuildXlaOpsPassFlags { - // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. - // Defaults to true. - bool tf_xla_enable_lazy_compilation; -}; - -// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment -// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS -// only the first time this routine is called. -const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc deleted file mode 100644 index 7277a1d1f8ad5fa045645ead839ab9efa01e89c7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static MarkForCompilationPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new MarkForCompilationPassFlags; - flags->tf_xla_auto_jit = 0; - flags->tf_xla_min_cluster_size = 2; - flags->tf_xla_max_cluster_size = std::numeric_limits::max(); - flags->tf_xla_clustering_debug = false; - flags->tf_xla_cpu_global_jit = false; - flags->tf_xla_clustering_fuel = std::numeric_limits::max(); - flags->tf_xla_fusion_only = false; - flag_list = new std::vector( - {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. " - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. 
Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions."), - Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, - "Places an artificial limit on the number of ops marked as " - "eligible for clustering."), - Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, - "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the MarkForCompilationPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc deleted file mode 100644 index 1bb2fce2dbad5bffce2e33b665b7222090d0855a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's xla_device module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static XlaDeviceFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new XlaDeviceFlags; - flags->tf_xla_compile_on_demand = false; - flag_list = new std::vector({ - Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, - "Switch a device into 'on-demand' mode, where instead of " - "autoclustering ops are compiled one by one just-in-time."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-XlaDeviceFlags* GetXlaDeviceFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h deleted file mode 100644 index 27b22121ac1e089bd5d5a494e1e3fb60b05bc76d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ - -// Legacy flags for the XLA bridge's xla_device module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// The values of flags associated with the XLA bridge's -// xla_device module. -typedef struct { - // Switch the CPU device into "on-demand" mode, where instead of - // autoclustering ops are compiled one by one just-in-time. - // Enabling this mode by a legacy flag is a temporary mechanism. When this - // feature is battle-tested, we will switch this to be a session option. 
- bool tf_xla_compile_on_demand; -} XlaDeviceFlags; - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc deleted file mode 100644 index ae17fdffb9b6a574449b7f3155e050b029702db7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include // NOLINT -#include - -#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" - -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -XlaOpsCommonFlags* flags; -std::vector* flag_list; -std::once_flag flags_init; - -void AllocateAndParseFlags() { - flags = new XlaOpsCommonFlags; - flags->tf_xla_always_defer_compilation = false; - flag_list = new std::vector({ - Flag("tf_xla_always_defer_compilation", - &flags->tf_xla_always_defer_compilation, ""), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { - std::call_once(flags_init, &AllocateAndParseFlags); - return *flags; -} -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h deleted file mode 100644 index 7c5c1818ef2d1dcf38c324a2c926db9c4bfa8ef5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ - -namespace tensorflow { -namespace legacy_flags { - -// Flags common to the _Xla* ops and their kernels. -struct XlaOpsCommonFlags { - // If true, _XlaCompile always refuses to compile the cluster, which means the - // XLA clusters always run in the TF executor. Defaults to false. - bool tf_xla_always_defer_compilation; -}; - -// Parses the flags in XlaOpsCommonFlags from the TF_XLA_FLAGS environment -// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS -// only the first time this routine is called. -const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 11975a6bb07e03dc3d182beb3748eb2559de7e25..25796435a5c87af5e252981abf96833f4cda9a5e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -61,14 +61,40 @@ struct OperationFilter { // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid // auto-clustering stateful RNG ops. 
bool allow_stateful_rng_ops; + + // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound + // to cluster ControlTrigger because of how we use deadness analysis. + bool allow_control_trigger; + + // Whether ops with dummy implementations are allowed. We avoid + // auto-clustering these ops so that the user is not surprised when XLA is + // implicitly enabled. If the user explicitly specifies to use XLA, it is fine + // to resort to a dummy implementation. Currently Assert and CheckNumerics ops + // have dummy XLA implementations. + bool allow_dummy_ops; + + // Whether ops that produce or consume DT_VARIANT values are allowed. We + // don't auto-cluster these ops because we don't yet support live-in or + // live-out DT_VARIANT values. + bool allow_ops_producing_or_consuming_variant; }; +bool IsDummyImplOp(absl::string_view op_name) { + return op_name == "Assert" || op_name == "CheckNumerics"; +} + bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || op_name == "TruncatedNormal"; } +bool OpProducesOrConsumesVariant(const Node& node) { + auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; + return absl::c_any_of(node.input_types(), is_variant) || + absl::c_any_of(node.output_types(), is_variant); +} + bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by @@ -225,6 +251,16 @@ bool IsCompilableCall(const NodeDef& call_def, IsStatefulRandomOp(node->type_string())) { return false; } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + return false; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + return false; + } + if (!op_filter.allow_ops_producing_or_consuming_variant && + 
OpProducesOrConsumesVariant(*node)) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, lib_runtime)) { @@ -406,8 +442,7 @@ Status FindCompilationCandidates( BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, &compile_time_const_nodes)); - int64& fuel = - legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; // Iterate over nodes in sorted order so that compiler fuel is deterministic. // We can't simply pass op_nodes().begin() and op_nodes().end to the @@ -450,9 +485,15 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); + bool always_auto_cluster = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; + OperationFilter op_filter; op_filter.allow_resource_ops = registration->compile_resource_ops; - op_filter.allow_stateful_rng_ops = registration->requires_compilation; + op_filter.allow_stateful_rng_ops = always_auto_cluster; + op_filter.allow_control_trigger = always_auto_cluster; + op_filter.allow_dummy_ops = always_auto_cluster; + op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster; if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, @@ -467,6 +508,21 @@ Status FindCompilationCandidates( VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; continue; } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op"; + continue; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + VLOG(2) << "Rejecting " << node->name() << ": dummy op (" + << node->type_string() << ")"; + continue; + } + if 
(!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + VLOG(2) << "Rejecting " << node->name() + << ": produces or consumes DT_VARIANT"; + continue; + } if (!op_filter.allow_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { @@ -570,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( // To set compilation to be on by default, change the following line. global_jit_level = OptimizerOptions::OFF; } - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_auto_jit == -1 || (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides @@ -597,11 +652,15 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - // We can always *compile* resource operations and stateful RNGs, even if we - // are sometimes unable to auto-cluster them. + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. OperationFilter op_filter; op_filter.allow_resource_ops = true; op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_dummy_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; + return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } @@ -611,12 +670,9 @@ Status MarkForCompilationPass::Run( // device ahead of time. 
OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); - bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool fusion_only = flags->tf_xla_fusion_only; - VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; @@ -635,9 +691,6 @@ Status MarkForCompilationPass::Run( return false; } - // If this device requires a JIT, we must say yes. - if (registration->requires_compilation) return true; - // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); @@ -674,18 +727,21 @@ Status MarkForCompilationPass::Run( return false; } - // Otherwise use the value of global_jit_level. - // Ignore enable_jit_by_default if global jit compilation for CPU - // is explicitly requested via tf_xla_cpu_global_jit flag - bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + // Otherwise use the value of global_jit_level and the device's + // autoclustering policy. 
bool should_compile = - (ignore_registration || registration->enable_jit_by_default) && - global_jit_level != OptimizerOptions::OFF; + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); if (!should_compile) { if (global_jit_level == OptimizerOptions::OFF) { VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; } else { - VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + VLOG(2) + << "Rejecting " << node->name() + << ": autoclustering for device only when requested explicitly."; } } return should_compile; @@ -915,8 +971,7 @@ Status MarkForCompilationPass::RunImpl( OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. @@ -1073,12 +1128,10 @@ Status MarkForCompilationPass::RunImpl( XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration); // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if the operator is placed on a device that requires - // compilation, or if it contains at least one op that is marked for + // Also, always compile if it contains at least one op that is marked for // compilation that is not an Identity op.
if (effective_cluster_sizes[cluster] >= min_cluster_size || - (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || - registration->requires_compilation) { + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) { string& name = cluster_names[cluster]; if (name.empty()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index ead1cf4fd5faff649e8518aaeb95935ccef4ca52..bf2c5508ea9e987e80093f4c2e15d3ff5191126f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -817,14 +818,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unordered_map clusters = GetClusters(*graph); - ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; - - // ctrl_trigger_a has inputs with mismatching deadness so it won't be - // clustered. ctrl_trigger_b is okay to cluster. - std::unordered_map expected_clusters( - {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); - EXPECT_EQ(clusters, expected_clusters); + // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so + // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't + // cluster it because of b/118970344. 
+ EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, RandomShape) { @@ -923,9 +920,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); - EXPECT_NE(clusters["test/shape_rng"], ""); - EXPECT_NE(clusters["test/reshape"], ""); - EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); + EXPECT_EQ(clusters["test/shape_rng"], ""); + EXPECT_EQ(clusters["test/reshape"], ""); } TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { @@ -1088,7 +1084,7 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { EXPECT_NE(clusters["test/c"], ""); } -TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { +TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) { Scope root = Scope::NewRootScope().ExitOnError(); Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); @@ -1104,5 +1100,128 @@ TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { EXPECT_EQ(clusters["test/a"], ""); EXPECT_EQ(clusters["test/b"], ""); } + +TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) { + absl::string_view xla_cpu_device = + "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + 
n->set_assigned_device_name(string(xla_cpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/check"], ""); + EXPECT_NE(clusters["test/greaterequal"], ""); + EXPECT_NE(clusters["test/assert"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterDummyOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/assert"], ""); + EXPECT_EQ(clusters["test/check"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_reserve"], ""); +} + +TEST(XlaCompilationTest, 
DontAutoClusterOpsConsumingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output dummy_input = + ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64); + Output variant_input = + ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT); + + // Create one more node so that we don't avoid creating a cluster solely + // because it would be trivial. + Output dummy_cast = + ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32); + + Output tensor_list_element_shape = ops::TensorListElementShape( + root.WithOpName("test/tensor_list_element_shape"), variant_input, + DT_INT32); + + root.graph()->AddControlEdge(dummy_cast.node(), + tensor_list_element_shape.node()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); +} + +TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(xla_cpu_device); + } + } + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters 
= GetClusters(*graph); + EXPECT_NE(clusters["test/tensor_list_reserve"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 5b9610322336acbcede0bef0538043b8ff917c16..42ea3926e16ae791dbe1bede3b8742383db7667c 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -26,6 +26,10 @@ limitations under the License. namespace tensorflow { namespace { + +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } + +namespace reduce_device_to_host_copies { Status FindNodesToDecluster(const Graph& graph, absl::flat_hash_set* result, absl::Span post_order) { @@ -133,11 +137,13 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { graph->RemoveEdge(out_edge_to_clone); } + if (n->out_edges().empty()) { + graph->RemoveNode(n); + } + return Status::OK(); } -bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } - // Clones nodes to outside their cluster to avoid device-to-host copies. For // instance, converts this: // @@ -164,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } // where the ===> arrow has a hostmem source and destination and would entail a // device to host copy if the source and destination were not in the same XLA // cluster. -Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been @@ -191,6 +197,10 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { } } + // Recompute post order since PartiallyDeclusterNode may have deleted nodes. 
+ post_order.clear(); + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); nodes_to_partially_decluster.clear(); TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); @@ -198,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { return Status::OK(); } +} // namespace reduce_device_to_host_copies +namespace reduce_recompilation { bool IsIntraClusterEdge(const Edge& edge) { absl::optional src_cluster_name = GetXlaClusterForNode(*edge.src()); @@ -210,7 +222,8 @@ bool IsIntraClusterEdge(const Edge& edge) { bool IsMustCompileDevice(const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { - return registration->requires_compilation; + return registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; } return false; @@ -260,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* must_compile) { // regress performance in any significant manner. We will have to revisit this // algorith with a more complex cost model if this assumption turns out to be // incorrect.
-Status DeclusterNodesToReduceRecompilations(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { std::vector compile_time_const_nodes(graph->num_node_ids()); TF_RETURN_IF_ERROR(BackwardsConstAnalysis( *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); @@ -313,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) { return Status::OK(); } - +} // namespace reduce_recompilation } // namespace Status PartiallyDeclusterPass::Run( @@ -325,8 +338,9 @@ Status PartiallyDeclusterPass::Run( Graph* graph = options.graph->get(); - TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); - TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + TF_RETURN_IF_ERROR( + reduce_device_to_host_copies::PartiallyDeclusterGraph(graph)); + TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 74d5ef57184197ad6e9e5048722e84863756a3f5..1fc5da5071f7aa6f6dd6636aacd60e33c12431a6 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -437,5 +437,32 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); } +TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) { + const char* const kClusteredProducer0Name = "ClusteredProducer0"; + const char* const kClusteredProducer1Name = "ClusteredProducer1"; + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName(kClusteredProducer0Name)); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, 
input, + builder.opts().WithName(kClusteredProducer1Name)); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h deleted file mode 100644 index 7c8c04152d2f3a0fd46711df24756b7e68b967ea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ -#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ - -#include -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -// A thread-safe, first-in-first-out queue. 
-template -class ProducerConsumerQueue { - public: - ProducerConsumerQueue() - : capacity_(std::numeric_limits::max()) {} - ~ProducerConsumerQueue() = default; - - // Wait until the queue is non-full, then append a copy of v. - void Put(const T &v); - - // Wait until the queue is non-empty, then remove and return the head value. - T Get(); - - // If the queue is non-empty, remove the head value, placing it in *pv, and - // return true; otherwise return false. - bool TryGet(T *pv); - - // Set the capacity of the queue; the queue is full whenever count() >= - // capacity(). The initial value is the maximum size_t. Requires size > 0. - void set_capacity(std::size_t size); - - // Return the capacity of the queue. - std::size_t capacity() const; - - // Return the number of elements in the queue. - std::size_t count() const; - - // Implementation details follow. Clients should ignore. - private: - mutable tensorflow::mutex mu_; // protects all fields below - tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); - tensorflow::condition_variable non_full_ GUARDED_BY(mu_); - std::size_t capacity_ GUARDED_BY(mu_); - std::deque queue_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); -}; - -// ------------------------------------------------------ -// Implementation details follow. Clients should ignore. - -// Wait until the queue is non-full, then append a copy of v. -template -void ProducerConsumerQueue::Put(const T &v) { - mutex_lock lock(mu_); - while (queue_.size() >= capacity_) { - non_full_.wait(lock); - } - queue_.push_back(v); - non_empty_.notify_one(); -} - -// Wait until the queue is non-empty, then remove and return the head value. 
-template -T ProducerConsumerQueue::Get() { - mutex_lock lock(mu_); - while (queue_.empty()) { - non_empty_.wait(lock); - } - non_full_.notify_one(); - T result_value = queue_.front(); - queue_.pop_front(); - return result_value; -} - -// If the queue is non-empty, remove the head value, placing it in *pv, and -// return true; otherwise return false. -template -bool ProducerConsumerQueue::TryGet(T *pv) { - mutex_lock lock(mu_); - bool got_element = !queue_.empty(); - if (got_element) { - non_full_.notify_one(); - *pv = queue_.front(); - queue_.pop_front(); - } - return got_element; -} - -// Set the capacity of the queue; the queue is full whenever count() >= -// capacity(). The initial value is the maximum size_t. Requires size > 0. -template -void ProducerConsumerQueue::set_capacity(std::size_t size) { - mutex_lock lock(mu_); - CHECK_NE(size, 0); - capacity_ = size; - non_full_.notify_all(); -} - -// Return the capacity of the queue. -template -std::size_t ProducerConsumerQueue::capacity() const { - mutex_lock lock(mu_); - std::size_t max_elements = capacity_; - return max_elements; -} - -// Return the number of elements in the queue. -template -std::size_t ProducerConsumerQueue::count() const { - mutex_lock lock(mu_); - std::size_t num_elements = queue_.size(); - return num_elements; -} -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc deleted file mode 100644 index f61260c6e52756ee039829afdc7452f5f760c221..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/producer_consumer_queue.h" - -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -typedef ProducerConsumerQueue IntQueue; - -// Insert integers between low inclusive and high exclusive into q. -void PushRange(IntQueue *q, int low, int high) { - while (low != high) { - q->Put(low); - VLOG(2) << "Pushing " << low; - ++low; - } -} - -// Push the numbers between 0 and 999 inclusive from several threads in the -// pool. -void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { - VLOG(1) << "Adding 20-36"; - pool->Schedule([queue] { PushRange(queue, 20, 36); }); - VLOG(1) << "Adding 7-20"; - pool->Schedule([queue] { PushRange(queue, 7, 20); }); - VLOG(1) << "Adding 36-501"; - pool->Schedule([queue] { PushRange(queue, 36, 501); }); - VLOG(1) << "Adding 501-1000"; - pool->Schedule([queue] { PushRange(queue, 501, 1000); }); - VLOG(1) << "Adding 0-5"; - pool->Schedule([queue] { PushRange(queue, 0, 5); }); - VLOG(1) << "Adding 5-7"; - pool->Schedule([queue] { PushRange(queue, 5, 7); }); -} - -// Pop elements from queue using Get(). Make sure that exactly elements -// were present and their values are all integers between 0 and high-1 -// inclusive. 
-void GetRange(IntQueue *queue, int high) { - VLOG(1) << "Testing Wait"; - std::vector results; - for (int i = 0; i != high; ++i) { - int r = queue->Get(); - VLOG(2) << "Waited and got " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK(results[i] == i); - } -} - -// Pop elements from queue using TryGet(). Make sure that exactly -// elements were present and their values are all integers between 0 and high-1 -// inclusive. -void TryGetRange(IntQueue *queue, int high) { - std::vector results; - // Give up if we don't get all the elements back from the queue - // in 10 seconds. - int timeout = 10; - int r; - for (int i = 0; i != high; ++i) { - while (!queue->TryGet(&r)) { - if (!timeout--) { - LOG(FATAL) << "Can't find all elements in the queue"; - } - VLOG(1) << "Sleeping for a second..."; - sleep(1); - } - VLOG(2) << "Popped " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - CHECK(!queue->TryGet(&r)); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK_EQ(i, results[i]); - } -} - -const int kNumThreads = 15; - -TEST(ProducerConsumerQueue, GetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - GetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, TryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - TryGetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, ParallelGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { GetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -TEST(ProducerConsumerQueue, ParallelTryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); - 
PushRanges(&queue, &pool); - } -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 4a5ea9e0a5f8cf79478069931da598099ae4e716..3df5479a55e841380ca7b8cdd0add9fd17487091 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -65,14 +66,14 @@ string XlaCompilationCache::DebugString() { // Compute a string signature which encodes the shapes of the // arguments in the supplied list. -string XlaCompilationCache::SignatureDebugString(const Signature& sig) { - string result = sig.name; - for (const auto& a : sig.arg_types) { +string XlaCompilationCache::Signature::HumanString() const { + string result = name; + for (const auto& a : arg_types) { absl::StrAppend(&result, ",", DataTypeString(a.first), a.second.DebugString()); } - for (const auto& v : sig.arg_values) { + for (const auto& v : arg_values) { absl::StrAppend(&result, "; ", v.DebugString()); } return result; @@ -84,7 +85,9 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { - if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { + if (arg_values[i].dtype() != other.arg_values[i].dtype() || + arg_values[i].shape() != other.arg_values[i].shape() || + arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { return false; } } @@ -108,96 +111,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( return h; } -Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, const std::map& constant_args, - const 
std::map& variable_args, OpKernelContext* ctx, - Signature* signature) { - signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.reserve(constant_args.size()); - - signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); - - for (int i = 0; i < ctx->num_inputs(); ++i) { - if (constant_args.count(i) > 0) { - // Use the values of compile time constants in the signature. - signature->arg_values.push_back(constant_args.at(i)); - } else if (variable_args.count(i) > 0) { - const OptionalTensor& variable = variable_args.at(i); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); - } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); - } - } else { - signature->arg_types.emplace_back(ctx->input_dtype(i), - ctx->input(i).shape()); - } - } - return Status::OK(); -} - -namespace { - -// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. -Status BuildArguments(const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - std::vector* args) { - args->resize(ctx->num_inputs()); - - for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { - XlaCompiler::Argument& arg = (*args)[input_num]; - if (constant_args.count(input_num) > 0) { - // Handles compile-time constants. - const Tensor& input = constant_args.at(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - } else if (variable_args.count(input_num) == 0) { - // Handles the non-constant arguments. 
- const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - } else { - // Handles resource variables. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - const OptionalTensor& variable = variable_args.at(input_num); - arg.name = variable.name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable.present) { - const Tensor& value = variable.value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; - } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do - // permit uninitialized variables. 
- arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); - } +xla::StatusOr +XlaCompilationCache::BuildSignature( + const NameAttrList& function, + absl::Span args) { + Signature signature; + signature.name = Canonicalize(function.name(), AttrSlice(&function.attr())); + for (const XlaCompiler::Argument& arg : args) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + signature.arg_values.push_back(arg.constant_value); + break; + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kResource: + signature.arg_types.emplace_back(arg.type, arg.shape); + break; + default: + return errors::InvalidArgument( + "Unhandled argument kind in XlaCompilationCache: ", + arg.HumanString()); } } - - return Status::OK(); + return std::move(signature); } -} // namespace - Status XlaCompilationCache::BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, @@ -227,25 +164,38 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { - // Set the compile threshold to 1 to implement CompileMode::kStrict. - int64 compile_threshold = - compile_mode == CompileMode::kLazy ? 
kDefaultCompilationThreshold : 1; - return CompileImpl(options, function, constant_args, variable_args, ctx, - compile_options, /*compile_single_op=*/false, + absl::optional compile_threshold; + if (compile_mode == CompileMode::kLazy) { + compile_threshold = kDefaultCompilationThreshold; + } + auto compile_fn = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + return compiler->CompileFunction(compile_options, function, args, result); + }; + return CompileImpl(options, function, args, compile_fn, /*compile_threshold=*/compile_threshold, out_compilation_result, out_executable); } +static bool IsMegamorphic(int64 compile_count, int64 execution_count) { + const int64 kCompileThreshold = 10; + const int64 kMinExecutionsPerCompile = 50; + + // This heuristic is trying to capture the following property: have we sunk a + // certain minimum amount of compile time into the cluster that didn't quite + // "pay off"? + return compile_count > kCompileThreshold && + execution_count < kMinExecutionsPerCompile * compile_count; +} + Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { @@ -253,54 +203,41 @@ Status XlaCompilationCache::CompileSingleOp( NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); - return CompileImpl(options, name, constant_args, variable_args, ctx, - compile_options, - /*compile_single_op=*/true, /*compile_threshold=*/1, + auto compile_op = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + std::vector result_dtypes(ctx->num_outputs()); + for (int i = 0; i < result_dtypes.size(); ++i) { + result_dtypes[i] = ctx->expected_output_dtype(i); + } + return 
compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(), + args, result_dtypes, result); + }; + return CompileImpl(options, name, args, compile_op, + /*compile_threshold=*/absl::nullopt, out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, - int64 compile_threshold, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { - VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << constant_args.size() - << " num_variable_args=" << variable_args.size(); - for (int i = 0; i < ctx->num_inputs(); i++) { - TensorShape shape = ctx->input(i).shape(); - VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i)) - << " present=" << ctx->has_input(i) - << " shape=" << shape.DebugString(); - } - for (auto& iterator : variable_args) { - const OptionalTensor& variable = iterator.second; - VLOG(2) << "variable present=" << variable.present - << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString() - << " TF arg= " << iterator.first; - } - VLOG(2) << "num_outputs = " << ctx->num_outputs(); - for (int i = 0; i < ctx->num_outputs(); i++) { - VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i); + VLOG(2) << "num_inputs=" << args.size(); + for (int i = 0; i < args.size(); i++) { + VLOG(2) << i << ": " << args[i].HumanString(); } } - TF_RET_CHECK(constant_args.size() + variable_args.size() <= - ctx->num_inputs()); - - Signature signature; - TF_RETURN_IF_ERROR( - 
BuildSignature(function, constant_args, variable_args, ctx, &signature)); + TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); + VLOG(2) << "Signature: " << signature.HumanString(); - VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not // protect the contents of the cache entry. Entry* entry; @@ -319,25 +256,67 @@ Status XlaCompilationCache::CompileImpl( // (since they get the benefit of XLA right away without waiting for warmup) // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at // most one cluster-compilation's worth of compile time). - bool is_first_execution = [&] { + bool is_first_execution; + + // We avoid compiling clusters that have "gone megamorphic" i.e. have an + // excessive amount of shape dynamism. + bool is_megamorphic; + + { mutex_lock lock(cluster_compile_stats_mu_); auto it = cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) .first; - return it->second.execution_count++ == 0; - }(); + is_first_execution = it->second.execution_count++ == 0; + + // The is_megamorphic bit is "sticky". We assume clusters that have been + // observed to be megamorphic once stay megamorphic forever. + it->second.is_megamorphic |= + IsMegamorphic(/*compile_count=*/it->second.compile_count, + /*execution_count=*/it->second.execution_count); + is_megamorphic = it->second.is_megamorphic; + } // Acquire the cache entry lock and compile, if necessary. // TODO(phawkins): this locking will need to be restructured when we implement // cache eviction. 
mutex_lock entry_lock(entry->mu); int64 current_request_count = ++entry->request_count; + VLOG(2) << "Compilation cache entry hit: " << entry->compiled + << " signature: " << signature.HumanString() << " with request count " + << current_request_count << " and compile threshold " + << compile_threshold.value_or(0); if (!entry->compiled) { - VLOG(2) << "Compilation cache miss for signature: " - << SignatureDebugString(signature) << " with request count " - << current_request_count << " and compile threshold " - << compile_threshold; - if (!is_first_execution && current_request_count < compile_threshold) { + const bool should_compile = [&] { + if (!compile_threshold.has_value()) { + // Lazy compilation is disabled. + return true; + } + + if (is_megamorphic) { + VLOG(3) << "Not compiling cluster " << function.name() + << " because it is megamorphic."; + return false; + } + + if (is_first_execution) { + return true; + } + + bool reached_compile_threshold = + current_request_count >= *compile_threshold; + if (!reached_compile_threshold) { + VLOG(3) + << "Not compiling cluster " << function.name() + << " because it has not reached compile threshold; threshold is " + << *compile_threshold << " execution count " + << current_request_count << "."; + } + return reached_compile_threshold; + }(); + + if (!should_compile) { + VLOG(2) << "Not compiling for signature: " << signature.HumanString(); *out_compilation_result = nullptr; *out_executable = nullptr; return Status::OK(); @@ -347,21 +326,12 @@ Status XlaCompilationCache::CompileImpl( const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) 
- std::vector args; - TF_RETURN_IF_ERROR( - BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - if (compile_single_op) { - entry->compilation_status = - compiler.CompileSingleOp(compile_options, signature.name, ctx, args, - &entry->compilation_result); - } else { - entry->compilation_status = compiler.CompileFunction( - compile_options, function, args, &entry->compilation_result); - } + entry->compilation_status = + compile_fn(&compiler, &entry->compilation_result); TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); entry->compilation_status = diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index b43e5d40e6402d24b80f7c689018d81e8a5d7f09..846d0c963dbfdf55f51120f2f138d12f5f63839b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -17,9 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/graph.pb.h" @@ -30,13 +33,6 @@ limitations under the License. namespace tensorflow { -// Struct that represents a possibly-absent Tensor. -struct OptionalTensor { - string name; // A descriptive name - bool present = false; // Is the tensor present? - Tensor value; // If present, what is the Tensor's value? -}; - // The XlaCompilationCache class caches the results of the XlaCompiler class, // which converts a Tensorflow graph into a compiled XLA compilation. 
// @@ -58,11 +54,7 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `constant_args` is a map of tensorflow argument number to its constant - // value. - // `variable_args` is a snapshot of the current values of the - // resource variable arguments to `function`; uninitialized variables are - // represented by an absent OptionalTensor. + // `args` is a description of the arguments to the computation. // // `compile_mode` controls the behavior of the compilation cache on a cache // miss. If `compile_mode` is `kLazy` then, based on some profitability @@ -78,9 +70,7 @@ class XlaCompilationCache : public ResourceBase { // outputs. Status Compile(const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, @@ -90,8 +80,7 @@ class XlaCompilationCache : public ResourceBase { // XlaCompiler::CompileFunction. Status CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -101,26 +90,6 @@ class XlaCompilationCache : public ResourceBase { string DebugString() override; - private: - // Common implementation of Compile and CompileSingleOp. 
- Status CompileImpl( - const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op, int64 compile_threshold, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - // Takes `result` which has been compiled from a Tensorflow subgraph to a - // XLA computation already, and generates an XLA LocalExecutable `executable`. - Status BuildExecutable(const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - std::unique_ptr* executable); - - xla::LocalClient* const client_; - const DeviceType device_type_; - // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. struct Signature { @@ -137,14 +106,35 @@ class XlaCompilationCache : public ResourceBase { struct Hash { uint64 operator()(const Signature& signature) const; }; + + // Returns a human-readable description of the signature. + string HumanString() const; }; - static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, Signature* signature); + static xla::StatusOr BuildSignature( + const NameAttrList& function, + absl::Span args); + + private: + // Common implementation of Compile and CompileSingleOp. 
+ Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); + + // Takes `result` which has been compiled from a Tensorflow subgraph to a + // XLA computation already, and generates an XLA LocalExecutable `executable`. + Status BuildExecutable(const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + std::unique_ptr* executable); + + xla::LocalClient* const client_; + const DeviceType device_type_; // The value associated with a cache entry. struct Entry { @@ -180,7 +170,13 @@ class XlaCompilationCache : public ResourceBase { // Cumulative time spent compiling the cluster. int64 cumulative_compile_time_us = 0; + + // True if we have decided that this cluster is too dynamic (i.e. its shapes + // change too frequently) to profitably JIT compile. Once a cluster is + // tagged megamorphic, it stays megamorphic forever. + bool is_megamorphic = false; }; + mutex cluster_compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..018c7c219f445bdca17f4f8b060e3678fe1be9ee --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(XlaCompilationCacheTest, SignatureEquality) { + NameAttrList fn; + fn.set_name("afunction"); + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_INT32; + args[0].shape = TensorShape({4, 0}); + args[0].constant_value = Tensor(DT_INT32, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].type = DT_FLOAT; + args[0].constant_value = Tensor(DT_FLOAT, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].shape = TensorShape({0, 4}); + args[0].constant_value = Tensor(DT_FLOAT, {0, 4}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3, + XlaCompilationCache::BuildSignature(fn, args)); + + std::vector signatures = {s1, s2, s3}; + for (int i = 0; i < signatures.size(); ++i) { + for (int j = 0; j < signatures.size(); ++j) { + EXPECT_EQ(i == j, signatures[i] == signatures[j]) + << signatures[i].HumanString() << " " << signatures[j].HumanString(); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 
31cb32e3059bc17e3cde36e5c9f90cc78a39e473..1fe612d43d10030675cf307b109e4dcc89cb2d79 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -187,8 +187,13 @@ Status XlaCompileOnDemandOp::Compile( compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); - return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - compile_options, result, executable); + + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arguments, variable_args, ctx, &args)); + + return cache->CompileSingleOp(options, args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index cbfeb38805038825917c16684b9c441818972042..9006dd514b166ad8291d2d437305e53de2a093a4 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,8 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend. 
#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -37,13 +37,15 @@ class XlaCpuDeviceFactory : public DeviceFactory { Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, const string& name_prefix, std::vector* devices) { - legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); + XlaDeviceFlags* flags = GetXlaDeviceFlags(); bool compile_on_demand = flags->tf_xla_compile_on_demand; XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = !compile_on_demand; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + compile_on_demand + ? 
XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested + : XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 2289abd2df372620c05db900bd46d1cdf6174377..4201ff91a89b1bee370e6a43337c51abe3bf974a 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -218,6 +218,9 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } if (device_context_) { device_context_->Unref(); } @@ -384,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; + tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true); std::shared_ptr stream; { mutex_lock lock(mu_); @@ -391,13 +395,46 @@ Status XlaDevice::Sync() { } if (!stream) return Status::OK(); - if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + Status status = stream->BlockHostUntilDone(); + { + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + } + TF_RETURN_IF_ERROR(status); + if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); } VLOG(1) << "XlaDevice::Sync completed"; return Status::OK(); } +void XlaDevice::Sync(const DoneCallback& done) { + VLOG(1) << "XlaDevice::Sync (asynchronous)"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) { + done(Status::OK()); + return; + } + + stream->ThenEnqueueOnBackgroundThread( + [this, stream, done](se::StreamExecutor*) { + 
tracing::ScopedActivity activity("XlaDevice::Sync::Callback", + /*is_expensive=*/true); + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + done(stream->ok() ? Status::OK() + : errors::Internal("XlaDevice::Sync() failed.")); + }); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -441,12 +478,55 @@ bool XlaDevice::RequiresSyncOnCompletion() const { return sync_on_completion_; } +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice* device) + : device_(device) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { + if (device_) { + mutex_lock lock(device_->mu_); + --device_->outstanding_asynchronous_operations_; + device_->outstanding_asynchronous_operations_cv_.notify_all(); + } +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + const XlaDevice::AsynchronousOperationHandle& other) + : device_(other.device_) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice::AsynchronousOperationHandle&& other) + : device_(other.device_) { + other.device_ = nullptr; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(const XlaDevice::AsynchronousOperationHandle& other) { + device_ = other.device_; + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; + return *this; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(XlaDevice::AsynchronousOperationHandle&& other) { + device_ = other.device_; + other.device_ = nullptr; + return *this; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, 
const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. - kernel_factory::OpKernelRegistrar::Factory factory = + OpKernel* (*factory)(OpKernelConstruction*) = [](OpKernelConstruction* context) -> OpKernel* { return new XlaCompileOnDemandOp(context); }; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 8881b697bc863e58006361924f7761c2e5bba493..c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -112,6 +112,12 @@ class XlaDevice : public LocalDevice { // compute, host-to-device, and device-to-host communication. bool use_multiple_streams = false; + // A function that describes how the on-host shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared + // shapes for computations. Must be non-null. XlaCompiler::ShapeRepresentationFn shape_representation_fn; // If padded_shape_fn is empty, a default implementation that returns @@ -129,6 +135,7 @@ class XlaDevice : public LocalDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override; + void Sync(const DoneCallback& done) override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -158,7 +165,30 @@ class XlaDevice : public LocalDevice { bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + // A simple RAII handle. On construction the device's + // outstanding_asynchronous_operations_ field is incremented; on destruction + // it is decremented. 
+ class AsynchronousOperationHandle { + public: + AsynchronousOperationHandle(XlaDevice* device); + ~AsynchronousOperationHandle(); + AsynchronousOperationHandle(const AsynchronousOperationHandle& other); + AsynchronousOperationHandle(AsynchronousOperationHandle&& other); + AsynchronousOperationHandle& operator=( + const AsynchronousOperationHandle& other); + AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); + + private: + XlaDevice* device_ = nullptr; + }; + + AsynchronousOperationHandle CreateAsynchronousOperationHandle() { + return AsynchronousOperationHandle(this); + } + private: + friend class AsynchronousOperationHandle; + xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -221,6 +251,11 @@ class XlaDevice : public LocalDevice { // True if the device requires XlaDevice::Sync to be called on completion // regardless of status. bool sync_on_completion_ GUARDED_BY(mu_) = false; + + // Count of outstanding asynchronous operations which must be zero on Sync() + // completion. 
+ int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; + condition_variable outstanding_asynchronous_operations_cv_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index eb3cf27624bb76058c8f0cf2e999818434d38d9e..6e6532731e64bd42ee56aa719748988f321e0f17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -70,9 +70,12 @@ XlaDeviceContext::XlaDeviceContext( CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = - [](const TensorShape& shape, - DataType dtype) -> xla::StatusOr { return shape; }; + shape_representation_fn_ = [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -99,7 +102,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, CHECK(xla_tensor); Status status = [&]() -> Status { - TF_ASSIGN_OR_RETURN(TensorShape shape, + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn_(device_tensor->shape(), device_tensor->dtype())); @@ -111,9 +114,15 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal())); + // The cpu_tensor and literal that we created here hold the data of host + // tensor in descending layout. The layout could be different from layout in + // device_tensor (but the logical shape has to be the same). The + // transfer_manager is responsible to do corresponding transposing when + // transferring the data to device. 
xla::BorrowingLiteral literal( static_cast(DMAHelper::base(cpu_tensor)), - xla_tensor->shaped_buffer().on_host_shape()); + xla::ShapeUtil::MakeShape(shape.element_type(), + xla::AsInt64Slice(shape.dimensions()))); VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); @@ -183,8 +192,15 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); + // Transfer manager requires the shape of the shaped buffer to be the same as + // literal shape except for the layout. Set the literal to use xla_tensor's + // shape as it is derived from the cpu_tensor's shape using + // shape_representation_fn_. xla::MutableBorrowingLiteral literal; - TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal)); + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout( + xla_tensor->shaped_buffer().on_host_shape()), + cpu_tensor, &literal)); TensorReference ref(*device_tensor); transfer_manager_->TransferLiteralFromDevice( diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 241ea8f60df8b66a9a39e3e176ecd4119f27d780..adf0f994b84d9fbf918a5b2478aa7d106853e038 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/kernels/stack.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .TypeConstraint("T") \ .HostMemory("input"), \ - RetvalOp); + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("StackV2") \ + .Device(DEVICE) \ + .HostMemory("max_size") \ + .HostMemory("handle"), \ + StackOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("T", TYPES), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("elem_type", TYPES), \ + StackPopOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp); -// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// TODO(b/118881356): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // and write the tensors they access in order to concatenate them into a batch. 
// We would need either to call out to an XLA computation to perform the diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 8f28b38b5e15052e9a14bd1ecf1b3047085d98f1..441970169581d53e0d8683b98d26712445b170ea 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, std::vector* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); @@ -53,24 +53,25 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, return Status::OK(); } - XlaDevice::Options options; - options.platform = platform.ValueOrDie(); - options.device_name_prefix = name_prefix; - options.device_name = DEVICE_XLA_GPU; - options.device_ordinal = 0; - options.compilation_device_name = DEVICE_GPU_XLA_JIT; - options.use_multiple_streams = false; - auto device = absl::make_unique(session_options, options); - - // TODO(b/78468222): Uncomment after fixing this bug - // status = device->UseGpuDeviceInfo(); - // if (!status.ok()) { - // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, - // " device"); - // return status; - // } - - devices->push_back(device.release()); + for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + XlaDevice::Options options; + options.platform = platform.ValueOrDie(); + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_XLA_GPU; + options.device_ordinal = i; + options.compilation_device_name = DEVICE_GPU_XLA_JIT; + options.use_multiple_streams = true; 
+ auto device = absl::make_unique(session_options, options); + + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, + " device number ", i); + return status; + } + + devices->push_back(device.release()); + } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index dc37362fd8611577317f62083796fd4d655e7066..e828bae865d630bd40f227943cdabb2d8d95ca48 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices( XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6e51bfca4a1504f8f11fe60159cb44b2ae19fa1b..3b0bda4caa161a7561a3098b89420329998ff8a7 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -191,40 +191,6 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { return Status::OK(); } -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. 
-ScopedShapedBuffer ExtractSubShapedBuffer( - ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator) { - const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_host_shape(), index); - const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_device_shape(), index); - - ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, - shaped_buffer->platform(), - shaped_buffer->device_ordinal()); - - auto& shape_tree = shaped_buffer->buffers(); - auto& sub_shape_tree = sub_shaped_buffer.buffers(); - sub_shape_tree.CopySubtreeFrom(shape_tree, - /*source_base_index=*/{index}, - /*target_base_index=*/{}); - shape_tree.ForEachMutableElement( - [index](const xla::ShapeIndex& shape_index, - tensorflow::se::DeviceMemoryBase* data) { - // shape_index is empty for the root node. Ignore that. - if (!shape_index.empty() && shape_index[0] == index) { - *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); - } - }); - return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); -} -} // namespace internal -using internal::ExtractSubShapedBuffer; - XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams) @@ -391,8 +357,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); if (use_multiple_streams_) { xla_tensor->ResetDefinitionEvent(definition_event, stream); } @@ -445,7 +410,6 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = 
ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - se::DeviceMemoryBase buffer = output.buffer({output_num}); if (variable_infos[i].var()->tensor()->dtype() != write.type) { return errors::Internal("Mismatched type in variable write"); @@ -455,18 +419,20 @@ Status XlaComputationLaunchContext::PopulateOutputs( Tensor output_tensor; TF_RETURN_IF_ERROR( ctx->allocate_temp(write.type, write.shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); - if (use_multiple_streams_) { - xla_tensor->ResetDefinitionEvent(definition_event, stream); + if (write.shape.num_elements() > 0) { + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); + if (use_multiple_streams_) { + xla_tensor->ResetDefinitionEvent(definition_event, stream); + } } *variable_infos[i].var()->tensor() = output_tensor; } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable_infos[i].var()->tensor() = output_tensor; } ++output_num; @@ -474,4 +440,60 @@ Status XlaComputationLaunchContext::PopulateOutputs( return Status::OK(); } +Status XlaComputationLaunchContext::BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args) { + args->resize(ctx->num_inputs()); + + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { + XlaCompiler::Argument& arg = (*args)[input_num]; + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. 
+ const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); + arg.constant_value = input; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); + } else { + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 81e205d13f711a701026b82100c17423595919ed..437db019a0eabe66417725148d8b121842e90479 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -35,6 +35,13 @@ limitations under the License. namespace tensorflow { class XlaAllocator; +// Struct that represents a possibly-absent Tensor. 
+struct OptionalTensor { + string name; // A descriptive name + bool present = false; // Is the tensor present? + Tensor value; // If present, what is the Tensor's value? +}; + // Takes a snapshot of the values of resource variable arguments, whose indices // are specified in `variable_indices` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is @@ -139,6 +146,13 @@ class XlaComputationLaunchContext { bool allocate_xla_tensors, bool use_multiple_streams); + // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch + // op. + static Status BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args); + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. // @@ -223,17 +237,6 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; -// Exposed in this header file for microbenchmarking purposes, but this is an -// internal implementation detail. -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. -xla::ScopedShapedBuffer ExtractSubShapedBuffer( - xla::ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator); -} // namespace internal - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc deleted file mode 100644 index a45932403ec1760d6b985d5357fd6d84fbf257a2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Contains microbenchmarks for performance critical functions in -// xla_launch_util.cc. - -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs -// (cardinality of each non-leaf node's children). -void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { - tensorflow::testing::StopTiming(); - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); - for (int i = 0; i < depth; ++i) { - std::vector shapes(fan_out, shape); - shape = xla::ShapeUtil::MakeTupleShape(shapes); - } - xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, - /*device_ordinal=*/0); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - // Extract a buffer from approximately the middle of the first level of the - // tree. 
- (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, - /*index=*/fan_out / 2, - /*allocator=*/nullptr) - .release(); - } -} - -BENCHMARK(BM_ExtractSubBuffer) - ->ArgPair(1, 4) - ->ArgPair(1, 8) - ->ArgPair(1, 32) - ->ArgPair(1, 64) - ->ArgPair(1, 128) - ->ArgPair(1, 256) - ->ArgPair(1, 512) - ->ArgPair(2, 4) - ->ArgPair(2, 8) - ->ArgPair(2, 32) - ->ArgPair(2, 64) - ->ArgPair(2, 128); - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - tensorflow::testing::RunBenchmarks(); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 6f8b198262dfb46b3fd76c52b5c005778cb906eb..d1f7f754c8338487557eda512c56be34c9e958b7 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -43,11 +43,10 @@ namespace tensorflow { } } -Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, +Status XlaTensor::AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal) { - xla::Shape on_host_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); xla::Shape on_device_shape = client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6d7a6fd66c80f6b8c29ad7adb4c9ae8505f5ed81..77e80aa2527ecc2221ac61f7b7e6ebcce0982931 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -50,7 +50,7 @@ class XlaTensor { // Assign the internal ShapedBuffer to new memory for the given dtype and // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it // is replaced and the managed memory deallocated. 
- Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal); // Some Tensors can have complex on-device shapes, including tuple shapes. To diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 664df006232f399ced0f29bf786c023a9688e64f..2b88a64fed322f662b3ff1d6bf706a813c52c758 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -375,6 +375,27 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "resampler_ops_test", + size = "small", + srcs = ["resampler_ops_test.py"], + disabled_backends = [ + # TODO(b/74459949) Support BatchDot in CPU backend. + "cpu", + "cpu_ondemand", + ], + # TODO(b/112295522): figure out how to make OSS build pass. + tags = ["no_oss"], + deps = [ + ":xla_test", + "//tensorflow/contrib/resampler:resampler_ops", + "//tensorflow/contrib/resampler:resampler_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -449,12 +470,11 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops", + "//tensorflow/python/ops/signal", ], ) @@ -816,8 +836,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - # Stack ops are not implemented in the on-demand compilation model yet. 
- disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index 69fb3ec2964a09508e612515b9e291fc14121d68..e9c2d363acab96c0fb968cb7f901ce105ea8703e 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() @@ -63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 # similarly for others. self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: @@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) 
self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAWithL1(self): for dtype in self.float_types: @@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.895489, -1.59555]), var0.eval()) + np.array([-0.895489, -1.59555]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.085339, -0.17989]), var1.eval()) + np.array([-0.085339, -0.17989]), self.evaluate(var1)) def testAdagradDAWithL1_L2(self): for dtype in self.float_types: @@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.046907, -0.093659]), var0.eval()) + np.array([-0.046907, -0.093659]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.004275, -0.009023]), var1.eval()) + np.array([-0.004275, -0.009023]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 
ab69319c59fb07e7ce56c3c287a50a6290effdfd..e26483303c3934fd51675cb1fbc998b276caf527 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def testTensorLearningRate(self): @@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def 
testSharing(self): @@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values. - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Mix the first and the second adagrad for 3 steps. ada_update1.run() ada_update2.run() ada_update1.run() # Validate updated params (the same as with only 1 Adagrad). self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 058576b3d4b695209952158769162bb24e7ccfce..8bcff9d379d34f8a6bb8b0fdc60b7588c6d80be9 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: 
grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRate(self): for dtype in self.float_types: @@ -117,23 +118,24 @@ class AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testSharing(self): for dtype in self.float_types: @@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase): beta1_power, beta2_power = opt._get_beta_accumulators() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - 
self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of intertwined Adam1 and Adam2. for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) if t % 2 == 0: update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) else: @@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase): var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index 3ed1d41b7121f44dd7470f61180f7a7055369174..961b46375c941bdc3922e460a2f58345086dbceb 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() @@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): for t in range(1, 4): update.run() - self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) var0_np, m0, v0 = 
adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(var0), rtol=1e-2) + self.assertAllCloseAccordingToType( + var1_np, self.evaluate(var1), rtol=1e-2) self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) @@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() # Run 3 steps of AdaMax for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) update.run() var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 1bc07ace23ccdc83103abe71ee11b72994c75a6d..a37c97e6d374440aeb860b9d02f2d5dd95c91f62 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch 
params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of AddSign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - var0_np, var0.eval(), half_rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + var0_np, self.evaluate(var0), half_rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 4e6dd6abfc9cdbabbbcdf0734be828f0aa28683b..332381c59eed06d5697e58efb1d8fa2b6ef604d2 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + import numpy as np from tensorflow.compiler.tests import xla_test @@ -967,7 +969,7 @@ class BinaryOpsTest(xla_test.XLATestCase): self._testBinary( array_ops.expand_dims, np.array([42], dtype=dtype), - np.int32(0), + np.array([0], dtype=np.int64), expected=np.array([[42]], dtype=dtype)) self._testBinary( array_ops.expand_dims, @@ -994,15 +996,21 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[[1, 2], [3, 4]]], dtype=dtype), np.int32(3), expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.array([2], dtype=np.int64), + expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) def testPad(self): - for dtype in self.numeric_types: + for dtype, pad_type in itertools.product( + self.numeric_types, [np.int32, np.int64]): self._testBinary( array_ops.pad, 
np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[1, 2], [2, 1]], dtype=np.int32), + [[1, 2], [2, 1]], dtype=pad_type), expected=np.array( [[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0], @@ -1016,7 +1024,7 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[0, 3], [2, 1]], dtype=np.int32), + [[0, 3], [2, 1]], dtype=pad_type), expected=np.array( [[7, 7, 1, 2, 3, 7], [7, 7, 4, 5, 6, 7], diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a57d1dc81ea2c9c188b0a3005904738aa8156bf3..15108487cfa8b9f07a5705fa6897fe16375ad7bf 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import googletest @@ -60,7 +61,7 @@ class CategoricalTest(xla_test.XLATestCase): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) - d = sess.run(op) + d = self.evaluate(op) batch_size, num_classes = logits.shape freqs_mat = [] @@ -85,9 +86,9 @@ class CategoricalTest(xla_test.XLATestCase): # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
@@ -112,7 +113,7 @@ class CategoricalTest(xla_test.XLATestCase): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, output_dtype=output_dtype) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= 0).sum() == 1000) self.assertTrue((y < 20).sum() == 1000) @@ -138,6 +139,36 @@ class CategoricalTest(xla_test.XLATestCase): chi2 = self._chi2(probs, freqs) self.assertLess(chi2, 1e-3) + def testStatelessMultinomialIsInRange(self): + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), + 1000, + seed_t, + output_dtype=output_dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) + + def testDeterminismMultinomial(self): + # Stateless values should be equal iff the seeds are equal (roughly) + num_samples = 10 + with self.cached_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + pure = stateless_random_ops.stateless_multinomial( + logits, num_samples, seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 88bd58b2da6b2892f898ad10f3467d8ce39d6388..ef2d7af69deeebd5f4c4c7225d7027f8f76bf861 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase): 
input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") output = math_ops.add(input1, input2) - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testAddFromCpuMultiple(self): @@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase): with self.test_scope(): output = math_ops.add(input1, input2) for _ in xrange(10): - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testDeadlock(self): diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2d225ad226cac368042b95eae8fc29e6fd8e82e0..deb9ac186e63a520054993cb56375f152c8c6587 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase): x2 = constant_op.constant(p2) with self.test_scope(): c = array_ops.concat([x1, x2], 0) - result = c.eval() + result = self.evaluate(c) self.assertAllEqual(result[:2, :], p1) self.assertAllEqual(result[2:, :], p2) @@ -150,7 +150,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 1) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) def testGradientsSimpleAll(self): @@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 0) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], 
inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 2) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, concat_dim) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) - self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) + self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t)) def testConcatNoScalars(self): with self.cached_session(): @@ -337,7 +337,7 @@ class ConcatOffsetTest(xla_test.XLATestCase): s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) - ans = sess.run(off) + ans = self.evaluate(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) @@ -350,7 +350,7 @@ class PackTest(xla_test.XLATestCase): s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): @@ -360,7 +360,7 @@ class PackTest(xla_test.XLATestCase): s1 = constant_op.constant(3, dtypes.int32) s2 = constant_op.constant(5, dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): @@ -370,7 +370,7 @@ class PackTest(xla_test.XLATestCase): s1 = constant_op.constant([[]], dtypes.int32) s2 = 
constant_op.constant([[]], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[[]], [[]], [[]]]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index d59fd0236f4f7da2bbfb3409342c7f70f8f5d1f6..01cc1b6392845be2418c50d55be97487eb290843 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) # We count the number of cells being added at the locations in the output. # At the center, #cells = kernel_depth * kernel_height * kernel_width @@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) for n in xrange(x_shape[0]): for k in xrange(f_shape[3]): @@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="VALID") - value = output.eval() + value = self.evaluate(output) cache_values = np.zeros(y_shape, dtype=np.float32) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 63cee550fde9d9d4314b1541fba191df776a4da2..76706ad40a0f0e9d033196d2e32e9b6c154268f0 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -106,7 +106,7 @@ class EagerTest(xla_test.XLATestCase): three = constant_op.constant(3) five = constant_op.constant(5) product = three * five - self.assertAllEqual(15, sess.run(product)) + 
self.assertAllEqual(15, self.evaluate(product)) def testDegenerateSlices(self): with self.test_scope(): diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index b3e13fbaa6b33bdaa1be123be558059e96de282e..61abf9c9c045b835b3a2e92fc588cd31f3da76ff 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -24,11 +24,10 @@ import numpy as np import scipy.signal as sps from tensorflow.compiler.tests import xla_test -from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.signal import signal from tensorflow.python.platform import googletest BATCH_DIMS = (3, 5) @@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase): def testFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, - spectral_ops.fft) + signal.fft) def testFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, - spectral_ops.fft2d) + signal.fft2d) def testFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), - spectral_ops.fft3d) + signal.fft3d) def testIFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, - spectral_ops.ifft) + signal.ifft) def testIFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, - spectral_ops.ifft2d) + signal.ifft2d) def testIFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), - spectral_ops.ifft3d) + signal.ifft3d) def testRFFT(self): self._VerifyFftMethod( INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), - lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + lambda x: signal.rfft(x, fft_length=[x.shape[-1].value])) def testRFFT2D(self): def _tf_fn(x): - 
return spectral_ops.rfft2d( + return signal.rfft2d( x, fft_length=[x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod( @@ -153,7 +152,7 @@ class FFTTest(xla_test.XLATestCase): x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) def _tf_fn(x): - return spectral_ops.rfft3d( + return signal.rfft3d( x, fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) @@ -162,7 +161,7 @@ class FFTTest(xla_test.XLATestCase): def testIRFFT(self): def _tf_fn(x): - return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), @@ -171,7 +170,7 @@ class FFTTest(xla_test.XLATestCase): def testIRFFT2D(self): def _tf_fn(x): - return spectral_ops.irfft2d( + return signal.irfft2d( x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( @@ -195,7 +194,7 @@ class FFTTest(xla_test.XLATestCase): s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) def _tf_fn(x): - return spectral_ops.irfft3d( + return signal.irfft3d( x, fft_length=[ x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 8c7edfd277c992c35a81dd5f261256a86352254e..91d77d2f791834346f43aecb60d116ddbf2faa6e 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase): enqueue_op.run() for i in xrange(len(elems)): - vals = dequeued_t.eval() + vals = self.evaluate(dequeued_t) self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): @@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([], size.get_shape()) enqueue_op.run() - self.assertEqual(1, size.eval()) + self.assertEqual(1, self.evaluate(size)) dequeued_t.op.run() - 
self.assertEqual(0, size.eval()) + self.assertEqual(0, self.evaluate(size)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 5b197afd655404e4e36a8b3442f8db60cb1d648d..b078053cdbd6d129645734492d34dd25d28ab3ef 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -50,14 +50,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivAdagradTest_AdagradPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Adagrad for a few steps for _ in range(steps): adagrad_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_FtrlPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params 
to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_GradientDescentPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run GradientDescent for a few steps for _ in range(steps): sgd_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testFtrlwithoutRegularization(self): for dtype in self.float_types: @@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-2.60260963, -4.29698515]), - var0.eval(), + self.evaluate(var0), float_rtol=1e-4, half_rtol=1e-2) self.assertAllCloseAccordingToType( np.array([-0.28432083, -0.56694895]), - var1.eval(), + 
self.evaluate(var1), float_rtol=1e-5, half_rtol=1e-2) @@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -167,10 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5, + np.array([-2.55607247, -3.98729396]), + self.evaluate(var0), + 1e-5, + 1e-5, float_rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) + np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5, + 1e-5) def testFtrlWithL1(self): for dtype in self.float_types: @@ -187,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -197,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-7.66718769, -10.91273689]), - var0.eval(), + self.evaluate(var0), rtol=1e-4, bfloat16_rtol=1e-1, bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( - np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) + np.array([-0.93460727, -1.86147261]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in 
self.float_types: @@ -219,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -228,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + np.array([-0.24059935, -0.46829352]), + self.evaluate(var0), + rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5) + np.array([-0.02406147, -0.04830509]), + self.evaluate(var1), + rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. 
@@ -254,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -263,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), + self.evaluate(var0), + rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" @@ -291,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): update1 = opt1.apply_gradients([(grads1, var1)]) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -301,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # var0 is experiencing L2 shrinkage so it should be smaller than var1 # in magnitude. 
- self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all()) accum0 = list(opt0._slots["accum"].values())[0].eval() accum1 = list(opt1._slots["accum"].values())[0].eval() # L2 shrinkage should not change how we update grad accumulator. diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index b1891b918c6584abce9da382088ed0037f5319fb..dd9b7f30efedaa45c96e60290b14a42d7f969b34 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testNestedFunctions(self): @@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_g = Foo(a, b) - result = sess.run(call_g) + result = self.evaluate(call_g) self.assertAllClose(result, expected, rtol=1e-3) def testFunctionMultipleRetvals(self): @@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testCompileTimeConstantsInDefun(self): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 561715ee1c3e0db37169cfd3fb431c0872987d75..6f51ae33a1b0fc8670ddf0cacb03a3b5a9176a91 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -593,6 +593,67 @@ class LazyCompilationTest(test.TestCase): self.assertFalse( InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) + def testIsMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return 
math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Make the cluster go megamorphic by running it with lots of shape + # signatures where the cluster is executed with each signature only a few + # times. Then check that we don't compile the cluster ever again. + + for shape in range(10, 50): + for _ in range(0, 49): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 50): + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] * 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + + def testIsNotMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Run the cluster with lots of shape signatures, but in a way that it + # isn't megamorphic (i.e. each shape signature sees a lot of executions). + # Then check that the cluster has not been marked as megamorphic. + + for shape in range(10, 50): + for _ in range(0, 1000): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 10): + sess.run(y, feed_dict={x: [0.] * 60}) + + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] 
* 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + if __name__ == "__main__": os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index c6ad67993e8bc196a74c9a328df8c9200c92c575..5dddf6ae4e8c8a3d5e9eb7b2c62298df02a0093c 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase): with self.test_scope(): actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, depth_radius, bias, alpha, beta) - expected_val = expected.eval() - actual_val = actual.eval() + expected_val = self.evaluate(expected) + actual_val = self.evaluate(actual) self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 265c0b6d1412de7be3a5bf5e79129cb330ceb162..fd02a50aff94d2bd2e180a092a27c8195178c5e5 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -88,7 +88,7 @@ class LSTMTest(test.TestCase): (basename, m_prev_scalar, c_prev_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM step. - sess.run(variables.global_variables_initializer()) + self.evaluate(variables.global_variables_initializer()) return sess.run([m, c]) def testLSTMCell(self): @@ -173,7 +173,7 @@ class LSTMTest(test.TestCase): (basename, m_init_scalar, c_init_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM layer. 
- sess.run(variables.global_variables_initializer()) + self.evaluate(variables.global_variables_initializer()) return sess.run(out_seq) def testLSTMLayer(self): diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index f77521a7c49dba39849869ddceb7c0e885147722..3416f7dbd6bdd264bf79785084f981f5b07cb8a9 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. 
self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) def testNesterovMomentum(self): for dtype in self.float_types: @@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase): var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) var1_np, accum1_np = self._update_nesterov_momentum_numpy( var1_np, accum1_np, 0.9, 0.1, 0.9) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: @@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. 
So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. 
self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 77bb839409f0c323ff6ed2c8d6bd105d3003b398..9671ae0ae973ff82d22744a1feb9b4293d94bbdd 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase): ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 sess.run(variables.variables_initializer([v])) - self.assertEqual(8.0, sess.run(out)) + self.assertEqual(8.0, self.evaluate(out)) def test_placeholder_with_default_fed(self): with self.cached_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 86536da7fed0e2309beb32fee9c7c605491592ed..5b35c20027700b34500a31e174061d7087094b61 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -91,8 +91,8 @@ class PowerSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of powersign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class PowerSignTest(xla_test.XLATestCase): ) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - 
self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index c41b4171e26af4f7ad0237d7407a5b3691299595..63cc51a470164915b2614a06d18ca1850bb64a3c 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -45,15 +45,17 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. for _ in range(3): update.run() - self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) - self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + self.assertAllClose( + np.array([-2.60260963, -4.29698515]), self.evaluate(var0)) + self.assertAllClose( + np.array([-0.28432083, -0.56694895]), self.evaluate(var1)) opt_vars = opt.variables() self.assertStartsWith(opt_vars[0].name, var0._shared_name) self.assertStartsWith(opt_vars[1].name, var1._shared_name) @@ -74,14 +76,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) - self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + self.assertAllClose(np.array([-1.60261, -2.296985]), self.evaluate(var0)) + self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1)) def testProximalAdagradWithL1(self): with self.cached_session(), self.test_scope(): @@ -98,14 +100,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad for _ in range(10): update.run() - self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) - self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + self.assertAllClose(np.array([-6.663634, -9.190331]), self.evaluate(var0)) + self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1)) def testProximalAdagradWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -122,15 +124,15 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -141,14 +143,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivAdagradwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 3d808e6b8a71ef9fa60b671d07bfd907e9f58efc..5aec433be765dd0a04bd7ab10d5c39a5a7f48c5c 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -42,15 +42,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) - self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + self.assertAllClose(np.array([-0.9, -1.8]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1)) def testProximalGradientDescentwithoutRegularization2(self): with self.cached_session(), self.test_scope(): @@ -64,15 +64,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent for _ in range(3): update.run() - self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) - self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + self.assertAllClose(np.array([0.1, 0.2]), self.evaluate(var0)) + self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1)) def testProximalGradientDescentWithL1(self): with self.cached_session(), self.test_scope(): @@ -86,15 +86,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps proximal gradient descent. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) - self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + self.assertAllClose(np.array([-1.988, -3.988001]), self.evaluate(var0)) + self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1)) def testProximalGradientDescentWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -108,15 +108,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Gradient Descent for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -127,14 +127,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivGradientDescentwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/qr_op_test.py 
b/tensorflow/compiler/tests/qr_op_test.py index 236b1b881dcaffc1a5b0c6395f0605c1d7ef0269..b4d4193e35f9e0e3b23d0242ed076dd811f4ee2b 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -63,7 +63,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. xx = math_ops.matmul(x, x, adjoint_a=True) identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) - precision = self.AdjustedNorm(xx.eval() - identity.eval()) + precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity)) self.assertTrue(np.all(precision < 5.0)) def _test(self, dtype, shape, full_matrices): diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206..1e913909452d54ed59f33bb0d313fd062570d459 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase): # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
@@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) @@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) - y = sess.run(x) + y = self.evaluate(x) def normal_cdf(x): return .5 * math.erfc(-x / math.sqrt(2)) @@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. @@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = np.diag(range(20)).flatten() # Compare sets to avoid randomness behavior changes but make sure still # have all the values. diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index cfccf5f3d2a0a3f2910b2ac1c2747381b172a685..a6b58020126a3297944f199e99b0801387615564 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2466,20 +2466,21 @@ TEST_F(OpTest, Pack) { }); } -// TODO(b/31741898): crashes on GPU. TEST_F(OpTest, Pad) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(); - // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. 
- // DataType tpaddings = Choose({DT_INT32, DT_INT64}); - DataType tpaddings = DT_INT32; + DataType tpaddings = Choose({DT_INT32, DT_INT64}); std::vector paddings_vec; - std::uniform_int_distribution distribution(0, 7); for (int i = 0; i < t_dims.size(); ++i) { - paddings_vec.push_back(distribution(generator())); - paddings_vec.push_back(distribution(generator())); + std::uniform_int_distribution pad_distribution(0, t_dims[i]); + int pad_size = pad_distribution(generator()); + std::uniform_int_distribution lower_distribution(0, pad_size); + int low_pad_size = lower_distribution(generator()); + paddings_vec.push_back(low_pad_size); + paddings_vec.push_back(pad_size - low_pad_size); + t_dims[i] -= pad_size; } Tensor paddings; CHECK( diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/compiler/tests/resampler_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ca0eab276b39f025d018edebb78eed7a8433bb --- /dev/null +++ b/tensorflow/compiler/tests/resampler_ops_test.py @@ -0,0 +1,205 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for resampler ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib import resampler +from tensorflow.contrib.resampler.ops import gen_resampler_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ResamplerOpsTest(xla_test.XLATestCase): + + def _assertForwardOpMatchesExpected(self, image_np, warp_np, expected): + with self.test_session() as sess, self.test_scope(): + input_image = array_ops.placeholder(image_np.dtype) + warp = array_ops.placeholder(warp_np.dtype) + resampled = resampler.resampler(input_image, warp, name='resampler') + out = sess.run(resampled, {input_image: image_np, warp: warp_np}) + + self.assertAllCloseAccordingToType( + expected, out, rtol=5e-3, half_rtol=1e-2, bfloat16_rtol=3e-2) + + def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, + expected_grad_data, expected_grad_warp): + with self.cached_session() as sess, self.test_scope(): + input_image = array_ops.placeholder(input_np.dtype) + warp = array_ops.placeholder(warp_np.dtype) + grad_output = array_ops.placeholder(grad_output_np.dtype) + + grad_data, grad_warp = gen_resampler_ops.resampler_grad( + input_image, warp, grad_output) + + grad_data_tf, grad_warp_tf = sess.run([grad_data, grad_warp], { + input_image: input_np, + warp: warp_np, + grad_output: grad_output_np + }) + + self.assertAllCloseAccordingToType( + expected_grad_warp, grad_warp_tf, half_rtol=1e-2, bfloat16_rtol=3e-2) + self.assertAllCloseAccordingToType( + expected_grad_data, grad_data_tf, half_rtol=1e-2, bfloat16_rtol=3e-2) + + def testSimple(self): + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [0, 5, 13, 54] + input_np = np.array(input_data, 
dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2] + warp_data = [0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[26.42]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + grad_output = np.ones([1, 1], dtype=dtype) + + expected_grad_data = [[[[0.12], [0.27999997]], [[0.18000001], + [0.42000002]]]] + + expected_grad_warp = [[26.60000038, 38.20000076]] + + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + + def testMultiChannel(self): + for dtype in self.float_types: + input_shape = [1, 2, 2, 3] + input_rgb_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2] + warp_data = [0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[59.58000183, 146.94000244, 107.37999725]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + grad_output = np.ones([1, 3], dtype=dtype) + + expected_grad_data = [[[[0.12, 0.12, 0.12], + [0.27999997, 0.27999997, 0.27999997]], + [[0.18000001, 0.18000001, 0.18000001], + [0.42000002, 0.42000002, 0.42000002]]]] + + expected_grad_warp = [[199, 30]] + + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + + def testBatch2Height3byWidth3RGB(self): + for dtype in self.float_types: + input_shape = [2, 3, 3, 3] + input_rgb_data = [ + 0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1, 30, 105, 2, 40, 115, + 3, 50, 125, 4, 60, 135, 5, 70, 145, 6, 0, 5, 13, 54, 135, 226, 37, 8, + 234, 90, 255, 1, 30, 105, 2, 40, 115, 3, 50, 125, 4, 60, 135, 5, 70, + 145, 6 + ] + input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape) + + # 2 batches and 2 samples for each batch. 
+ warp_shape = [2, 2, 2] + warp_data = [0.7, 0.6, 1, 0.7, 0.9, 1.2, 1.3, 1.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + + expected_forward = [[[43.92, 128.4, 65.86], [37.2, 114., 69.2]], + [[40.6, 122.8, 2.5], [51., 126, 4.1]]] + + self._assertForwardOpMatchesExpected(input_np, warp_np, expected_forward) + + expected_grad_data = [[[[0.12, 0.12, 0.12], + [0.57999998, 0.57999998, 0.57999998], + [0., 0., 0.]], + [[0.18000001, 0.18000001, 0.18000001], + [1.12, 1.12, 1.12], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0.08000001, 0.08000001, 0.08000001], + [0.99999988, 0.99999988, 0.99999988], + [0.11999997, 0.11999997, 0.11999997]], + [[0.02000001, 0.02000001, 0.02000001], + [0.60000008, 0.60000008, 0.60000008], + [0.17999998, 0.17999998, 0.17999998]]]] + expected_grad_warp = [[[33.39999008, -96.20000458], [-26.10000229, + -278.]], + [[-162.99998474, 39.99999619], [21., 63.]]] + + grad_output = np.ones([2, 2, 3], dtype=dtype) + self._assertBackwardOpMatchesExpected(input_np, warp_np, grad_output, + expected_grad_data, + expected_grad_warp) + + def testOutOfBoundWarps(self): + # (x, y) are both less than 0. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-1, -1, 0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [27.62]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # One of (x, y) is less than 0. 
+ for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-1, 0.1, 0.7, 0.6] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [27.62]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # Both of (x, y) are greater than image size. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [-0.1, 0.1, 1.2, 2.1] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [0.0]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + # One of (x, y) is greater than image size. + for dtype in self.float_types: + input_shape = [1, 2, 2, 1] + input_data = [10, 5, 13, 54] + input_np = np.array(input_data, dtype=dtype).reshape(input_shape) + + warp_shape = [1, 2, 2] + warp_data = [0.1, -0.1, 1.2, 0.1] + warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape) + expected = [[[0.0], [0.0]]] + self._assertForwardOpMatchesExpected(input_np, warp_np, expected) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index 8840a1329a907bddc6ef1cb6dd1c2a6d234def5c..dc3e90b4afa41c08d899ee195d42fb91678bad1c 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -76,7 +76,7 @@ class RmspropTest(xla_test.XLATestCase): rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered) rms_update = rms_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) mg0 = rms_opt.get_slot(var0, "mg") self.assertEqual(mg0 is not None, centered) @@ -92,12 
+92,12 @@ class RmspropTest(xla_test.XLATestCase): self.assertTrue(mom1 is not None) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of RMSProp for _ in range(3): - rms_update.run() + self.evaluate(rms_update) var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( var0_np, @@ -118,14 +118,14 @@ class RmspropTest(xla_test.XLATestCase): # Validate updated params if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0)) + self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1)) + self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0)) + self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1)) + self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0)) + self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 46ca371c8abf1cb4710717a183ee12820c4c4ca0..d7e26d79c4c054860ade5c8960a3bca984e020b0 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -79,7 +79,8 @@ class 
TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() self.assertAllEqual( - convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval()) + convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), + self.evaluate(c0)) def testTensorArrayWritePack(self): for dtype in self.numeric_tf_types: @@ -97,7 +98,7 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() - self.assertAllEqual([3, 0, 1], c0.eval().shape) + self.assertAllEqual([3, 0, 1], self.evaluate(c0).shape) def _testTensorArrayWriteConcat(self, tf_dtype): with self.cached_session(), self.test_scope(): @@ -113,8 +114,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.concat() self.assertAllEqual( - convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], - [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval()) + convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], + [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0)) def testTensorArrayWriteConcat(self): for dtype in self.numeric_tf_types: @@ -341,7 +342,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): - r0_bad.eval() + self.evaluate(r0_bad) # Test reading from a different index than the one we wrote to w0.read(1) @@ -422,7 +423,7 @@ class TensorArrayTest(xla_test.XLATestCase): w2 = h2.write(0, 5.0) r2 = w2.read(0) r = r1 + r2 - self.assertAllClose(9.0, r.eval()) + self.assertAllClose(9.0, self.evaluate(r)) def _testTensorArrayGradientWriteReadType(self, dtype): with self.cached_session() as session, self.test_scope(): @@ -504,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase): [-0.5, 1.5], # read(0) gradient [20.0, 30.0, 40.0, 50.0], # concat gradient ]) - grad_vals = sess.run(grad_r) # 2 + 2 entries + grad_vals = self.evaluate(grad_r) # 2 + 2 entries self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], 
grad_vals[1]) @@ -526,7 +527,7 @@ class TensorArrayTest(xla_test.XLATestCase): with ops.control_dependencies([r0_readtwice]): r1_readtwice = w_readtwice.read(0) - self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) + self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice)) def _testTensorArrayGradientUnpackRead(self): with self.cached_session() as session, self.test_scope(): @@ -592,7 +593,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() - self.assertAllEqual(3, s.eval()) + self.assertAllEqual(3, self.evaluate(s)) def testWriteCloseTensorArray(self): with self.cached_session(), self.test_scope(): @@ -722,7 +723,7 @@ class TensorArrayTest(xla_test.XLATestCase): # r = acc2.stack() # grad = gradients_impl.gradients(r, [x])[0] - # self.assertAllClose(31.0, grad.eval()) + # self.assertAllClose(31.0, self.evaluate(grad)) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): with self.cached_session() as session, self.test_scope(): @@ -912,7 +913,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertEqual(0, ta.size().eval()) ta = ta.unstack(array_ops.zeros([0, 3, 5])) packed = ta.stack() - self.assertAllEqual([0, 3, 5], packed.eval().shape) + self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape) # Concatenating zero tensors along their first dimension gives a # first dimension of zero self.assertAllEqual([0, 5], ta.concat().eval().shape) @@ -1041,8 +1042,8 @@ class TensorArrayTest(xla_test.XLATestCase): (read0, read1, size0, size1)) # Tests that the control dependencies was added and executed. - self.assertEqual(1, v0.eval()) - self.assertEqual(1, v1.eval()) + self.assertEqual(1, self.evaluate(v0)) + self.assertEqual(1, self.evaluate(v1)) # Tests correct TensorArray. 
self.assertEqual(read0_v, 0) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dd2c252d383bca9c59033ac07e442b487e4975a6..e776c8a951c7ac24c65408a67007b03ae07e8be0 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -40,6 +40,19 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" + def testWriteEmptyShape(self): + # Verifies that we can pass an uninitialized variable with an empty shape, + # assign it a value, and successfully return it. + for dtype in self.numeric_types: + with self.test_session() as sess, self.test_scope(): + zeros = np.zeros([3, 0], dtype=dtype) + v = resource_variable_ops.ResourceVariable(zeros) + p = array_ops.placeholder(dtype) + x = v.assign(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(zeros, sess.run(y, {p: zeros})) + def testOneWriteOneOutput(self): # Regression test for a bug where computations with one non-constant # output and one variable update were mishandled. 
@@ -216,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[3], [7]]) + self.assertAllEqual(self.evaluate(read), [[3], [7]]) def testScatterSub(self): with self.test_session() as sess, self.test_scope(): @@ -229,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[4], [-1]]) + self.assertAllEqual(self.evaluate(read), [[4], [-1]]) def testScatterMul(self): with self.test_session() as sess, self.test_scope(): @@ -242,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDiv(self): with self.test_session() as sess, self.test_scope(): @@ -255,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[2]]) + self.assertAllEqual(self.evaluate(read), [[2]]) def testScatterMin(self): with self.test_session() as sess, self.test_scope(): @@ -268,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + 
self.assertEqual(self.evaluate(read), [[3]]) def testScatterMax(self): with self.test_session() as sess, self.test_scope(): @@ -281,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterUpdate(self): with self.test_session() as sess, self.test_scope(): @@ -294,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_update( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterAddScalar(self): with self.test_session() as sess, self.test_scope(): @@ -307,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterSubScalar(self): with self.test_session() as sess, self.test_scope(): @@ -320,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[-1]]) + self.assertEqual(self.evaluate(read), [[-1]]) def testScatterMulScalar(self): with self.test_session() as sess, self.test_scope(): @@ -333,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant(5, dtype=dtypes.int32))) read = 
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDivScalar(self): with self.test_session() as sess, self.test_scope(): @@ -346,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[2]]) + self.assertEqual(self.evaluate(read), [[2]]) def testScatterMinScalar(self): with self.test_session() as sess, self.test_scope(): @@ -359,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMaxScalar(self): with self.test_session() as sess, self.test_scope(): @@ -372,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterNdAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -387,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) def testScatterNdUpdateAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -403,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase): 
gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) class StridedSliceAssignChecker(object): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5fc9a352ff930c7d281ec5c52168580e453c04b0..3458c7f1c40cd70187e209eb40db24245d595d04 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -166,6 +166,7 @@ cc_library( "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", + "xla_expression.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", @@ -180,6 +181,7 @@ cc_library( "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", + "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", @@ -193,6 +195,7 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -201,13 +204,13 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -217,6 +220,8 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -362,8 +367,12 @@ tf_cc_test( tf_cc_test( name = "xla_compiler_test", - srcs = ["xla_compiler_test.cc"], + srcs = [ + "xla_compiler_test.cc", + "xla_expression_test.cc", + ], deps = [ + ":common", ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", @@ -386,6 +395,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -428,14 +438,13 @@ cc_library( name = "dump_graph", srcs = [ "dump_graph.cc", - "dump_graph_flags.cc", - "dump_graph_flags.h", ], hdrs = [ "dump_graph.h", ], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 380c6a7e23da92d949b26876836b999bf6406c6c..1de85004a51bea464f8f0166511402e5dd85ac14 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -61,8 +61,7 @@ string MakeUniqueFilename(string name) { string WriteTextProtoToUniqueFile( Env* env, const string& name, const char* proto_type, const ::tensorflow::protobuf::Message& proto) { - const string& dirname = - legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; + const string& dirname = GetDumpGraphFlags()->tf_dump_graph_prefix; Status status = env->RecursivelyCreateDir(dirname); if (!status.ok()) { LOG(WARNING) << "Failed to create " << dirname << " for dumping " diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc deleted file mode 100644 index a6c908ba011afb90fabacc855df8c6afbb35d254..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's dump_graph module. 
- -#include -#include - -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static DumpGraphFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new DumpGraphFlags; - flags->tf_dump_graph_prefix = "/tmp/"; - flag_list = new std::vector({ - Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix, - "Path prefix to which graphs dumped during debugging should be " - "written."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -DumpGraphFlags* GetDumpGraphFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h deleted file mode 100644 index 80a3307d920f2cc3d668d507786a02e43589f86f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ - -// Legacy flags for the XLA bridge's dump_graph module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// dump_graph module. -typedef struct { - string tf_dump_graph_prefix; // Path prefix to which graphs dumped during - // debugging should be written. -} DumpGraphFlags; - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-DumpGraphFlags* GetDumpGraphFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f818d80022da0bad851c896f2714c15b20b22195..3dfd3f854c8646ebbf06d3378201d22e8741b7eb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -75,6 +75,25 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library) { + return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr, + graph_def, library); +} + +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library) { + FunctionDefLibrary function_lib = graph_def->library(); + Graph graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library)); + graph.ToGraphDef(graph_def); + std::swap(*graph_def->mutable_library(), function_lib); + return Status::OK(); +} + Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, @@ -242,23 +261,20 @@ Status FunctionalizeControlFlowPass::Run( continue; } const string func_attr = it->second; - if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != - kNodeTypeToFunctionAttrMapping->end()) { - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". 
Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_f15n_")); - bool modified; - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); - if (modified) { - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + bool modified; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); } } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index ba99205640ccdc83a3a4d50e3ec474907894a835..91d33fa405834d7f1f8f66180583580f4f2e448a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -33,6 +33,12 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library); + // This pass looks at the graph and all associated FunctionDefs, and turns // traditional control flow structure (Switch/Merge/etc.) into functional // control flow structure (If/While). 
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c3841f996f801e855da75b23f01d41674ec51c4d..9784985af83a18619d837528f99a60b98a501ec5 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -95,77 +95,87 @@ TEST(FunctionalizeControlFlow, Conditional) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - 
InstantiateFunctionForTest(else_fn.name(), library, &else_result)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, - then_fn, else_fn); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // then body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + // then body. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - // else body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // else body. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -239,75 +249,77 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, 
graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten = ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } -// @function.Defun(noinline=True) -// def increment_fn(x): -// return [x + 1] -// Define the above function, and add it to the given graph. It's used as the -// while loop body in NoinlineLoopBody test. -Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { +FunctionDef GetNoinlineFunctionDef() { FunctionDef fdef = FunctionDefHelper::Create( "increment_fn", {"x:int32"}, {"add:int32"}, {}, { @@ -316,8 +328,17 @@ Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { }, {{"add", "add_0:z:0"}}); (*fdef.mutable_attr())["_noinline"].set_b(true); + return fdef; +} + +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. 
+Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { FunctionDefLibrary fdef_lib; - *(fdef_lib.add_function()) = fdef; + *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); NodeDef increment_fn; increment_fn.set_name(node_name); @@ -376,55 +397,88 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { FunctionLibraryDefinition lookup_lib(graph.flib_def()); FunctionLibraryDefinition library(OpRegistry::Global(), {}); // Function increment_fn will be copied from lookup_lib to library. - TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + *(optimized_graph_def.mutable_library()->add_function()) = + GetNoinlineFunctionDef(); - NameAttrList cond_fn, body_fn; - TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &lookup_lib, &optimized_graph_def, &library)); + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - 
std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK( + AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } +} - // Body graph. 
+TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), source); TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - NodeDef retval; - retval.set_name("_retval0_RetVal"); - retval.set_op(FunctionLibraryDefinition::kRetOp); - *retval.add_input() = noinline_node_name; - (*retval.mutable_attr())["T"].set_type(DT_INT32); - (*retval.mutable_attr())["index"].set_i(0); - Status status; - scope.graph()->AddNode(retval, &status); - TF_ASSERT_OK(status); - - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } - InstantiationResultForTest result; - // Verify that increment_fn has been copied to library. - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + graph_def.clear_library(); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - // Ignore the function library when comparing the graphs. 
- expected.clear_library(); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + Status status = + FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library); + EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); } // Tests functionalizing OneLoopVar where the loop value is not used post the @@ -467,65 +521,72 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten 
= ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, 
result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -608,86 +669,95 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); - auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); - auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto three = ops::Const(scope.WithOpName("while/cond/three") + // Condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const(scope.WithOpName("while/cond/ten") .WithControlDependencies(arg0.output), - 3); - auto cond_add = - ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); - auto ten = ops::Const( - scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - - auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); - auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); - - auto one = ops::Const( - scope.WithOpName("while/add/one").WithControlDependencies(identity_x), - 1); - auto two = ops::Const( - scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), - 2); + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); - auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + + auto identity_x = + ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = + ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -841,177 +911,192 @@ TEST(FunctionalizeControlFlow, Complex) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList outer_cond_fn, outer_body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); - - // Outer graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); - auto y = ops::Add(scope.WithOpName("y"), x, three); - - auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, - TensorShape({})); - - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - - auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Outer condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto ten = ops::Const( - scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Outer body graph. - NameAttrList inner_cond_fn, inner_body_fn; - { - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); - - // Find the inner condition and body names. 
- TF_EXPECT_OK( - FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); - auto one_j = ops::Const( - scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); - auto while_op = - ops::While(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); - - auto one_outer = ops::Const( - scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); - auto add_i = - ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(absl::Span{ - while_op[0].op(), while_op[1].op()}), - identity_i, one_outer); - - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Inner condition graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto five = ops::Const( - scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); - auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList outer_cond_fn, outer_body_fn; TF_EXPECT_OK( - InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Inner body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto identity_j = - ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); - auto identity_k = - ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); - - auto mul_jk = - ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); - auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); - auto assign = ops::AssignAddVariableOp( - scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - - auto one = ops::Const( - scope.WithOpName("outer/inner/One") - .WithControlDependencies( - absl::Span{assign.operation}), - 1); - auto add_j = - ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); - auto retval1 = - ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + // Outer condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - GraphDef expected; - 
TF_EXPECT_OK(scope.ToGraphDef(&expected)); + // Outer body graph. + NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + + // Find the inner condition and body names. + TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), + 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(absl::Span{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Inner body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 706ed4f5bbfac60de4653cc8c326214cd4d8d886..efb75749722893100494e089c0beb96944e9f1d4 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -23,9 +23,9 @@ limitations under the 
License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -51,12 +52,11 @@ namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { - auto builder = ctx->builder(); auto client = ctx->compiler()->client(); - std::vector compile_time_constant_flags(expressions.size()); + std::vector arg_must_be_compile_time_constant(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant, /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); @@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); - if (arg.type == DT_RESOURCE) { - return errors::InvalidArgument( - "Resource as function argument is not yet implemented."); - } else if (expressions[i]->has_constant_value()) { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = expressions[i]->constant_value(); - } else if (compile_time_constant_flags[i]) { - arg.kind 
= XlaCompiler::Argument::kConstant; - TF_RET_CHECK(expressions[i]->resource() == nullptr) - << "Input with resource is not yet implemented."; - TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( - expressions[i]->handle())); - TF_ASSIGN_OR_RETURN(auto literal, - client->ComputeConstant(constant_graph)); - TF_RETURN_IF_ERROR( - LiteralToHostTensor(literal, arg.type, &arg.constant_value)); - } else { - arg.kind = XlaCompiler::Argument::kParameter; + switch (expressions[i]->kind()) { + case XlaExpression::Kind::kConstant: + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = expressions[i]->constant_value(); + break; + case XlaExpression::Kind::kXlaOp: + if (arg_must_be_compile_time_constant[i]) { + TF_ASSIGN_OR_RETURN(absl::optional value, + expressions[i]->ResolveConstant(client)); + if (!value.has_value()) { + return errors::InvalidArgument( + "Argument to function must be a compile-time constant, but " + "unable to resolve argument value to a constant."); + } + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *value; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + break; + case XlaExpression::Kind::kResource: + return errors::Unimplemented( + "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument("Invalid function argument"); } } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 9ee4178f5c213e919255bb33e9b15800a77256e6..d85b4f5ae0cb9c7d2476158a5830f921742ae980 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -178,6 +178,32 @@ tf_kernel_library( ], ) +# A separate cc_library for resampler_ops is needed because resampler is in +# contrib/, and thus the declaration of resampler cannot be pulled into the deps +# of xla_ops. 
Therefore, resampler_ops is its own cc_library target, and its +# corresponding tf_kernel_library is defined in contrib/resampler/BUILD. +cc_library( + name = "resampler_ops", + srcs = ["resampler_ops.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + cc_library( name = "conv_op_helpers", srcs = ["conv_op_helpers.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 276d744c096f8996c774964204feaa3762bdb844..2db2514397deca39e6874cf994532a20d2186316 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel { } const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; - if (arg.resource() != nullptr) { - ctx->SetResourceOutput(0, arg.resource()); - } else if (arg.has_constant_value()) { - ctx->SetConstantOutput(0, arg.constant_value()); - } else { - ctx->SetOutput(0, arg.handle()); - } + OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, + errors::InvalidArgument("Invalid/missing argument expression")); + ctx->SetOutputExpression(0, arg); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 9fa57b76f8e3649c03fe41f39638b88cb065ed0e..c022284fec6bc91951170e243ea3609c8d5d0c43 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -94,14 +94,10 @@ class BCastGradArgsOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), errors::InvalidArgument("In[", i, "] must be a vector.", in_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal)); - - BCast::Vec vec; - for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(literal.Get({i})); - } - shapes.push_back(vec); + std::vector vec; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec)); + + shapes.push_back(BCast::Vec(vec.begin(), vec.end())); } BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, 
bcast.IsValid(), diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index ad85940920ebb82e72331516e3fe46c79f853892..3e398fff951a211f5af42d26983bc7473bddde63 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,10 +21,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -57,8 +60,6 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::XlaBuilder* builder = ctx->builder(); - xla::Shape uniform_shape; int class_dimension; if (num_samples > 1) { @@ -83,16 +84,16 @@ class CategoricalOp : public XlaOpKernel { xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); class_dimension = 1; } - xla::XlaOp uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type)); + xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. 
auto softmax_entries = - xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + xla::Sub(logits, log_uniforms, /*broadcast_dimensions=*/{0, class_dimension}); xla::PrimitiveType xla_output_type; @@ -107,6 +108,16 @@ class CategoricalOp : public XlaOpKernel { ctx->SetOutput(0, argmax); } + virtual xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, + xla::PrimitiveType type, + XlaOpKernelContext* ctx) { + xla::XlaBuilder* builder = ctx->builder(); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); + return xla::Log(-xla::Log(uniforms)); + } + private: TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); }; @@ -115,5 +126,48 @@ class CategoricalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"), CategoricalOp); +class StatelessCategoricalOp : public CategoricalOp { + public: + explicit StatelessCategoricalOp(OpKernelConstruction* ctx) + : CategoricalOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, + XlaOpKernelContext* ctx) override { + xla::XlaOp seed = ctx->Input(2); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::XlaBuilder* builder = ctx->builder(); + if (uniform_shape.element_type() == xla::BF16) { + uniform_shape.set_element_type(xla::F32); + } + auto uniforms = xla::StatelessRngUniform( + {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), + XlaHelpers::One(builder, DT_FLOAT)); + return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape seed_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + 
CategoricalOp::Compile(ctx); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp); +}; + +REGISTER_XLA_OP(Name("StatelessMultinomial") + .CompileTimeConstantInput("num_samples") + .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("Tseed", DT_INT32), + StatelessCategoricalOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e28755dd73bf6f8e1518dda2494cade79b7db22e..cd7c7f4a82df7a65829787efcb1fd2f77870e945 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -45,15 +46,13 @@ class ConcatBaseOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_tensor_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); - // TODO(annarev): add a helper to support int64 input. 
- const int32 concat_dim = literal.Get({}); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_tensor_shape.DebugString())); + int64 concat_dim; + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim)); std::vector values; std::vector shapes; @@ -63,9 +62,7 @@ class ConcatBaseOp : public XlaOpKernel { const TensorShape& input_shape = shapes[0]; int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(ctx, - (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), + OP_REQUIRES(ctx, 0 <= axis && axis < input_dims, errors::InvalidArgument( "ConcatOp : Expected concatenating dimensions in the range " "[", @@ -75,14 +72,11 @@ class ConcatBaseOp : public XlaOpKernel { // elements. std::vector input_data; int output_concat_dim = 0; - const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; - const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( - ctx, - in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), + ctx, in_shape.dims() == input_dims, errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", input_shape.DebugString(), " vs. 
shape[", i, @@ -131,11 +125,10 @@ class ConcatOffsetOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_shape = ctx->InputShape(0); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_shape.DebugString())); for (int i = 1; i < ctx->num_inputs(); ++i) { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), errors::InvalidArgument("input ", i, @@ -162,39 +155,38 @@ class ConcatOffsetOp : public XlaOpKernel { // [0, 5, 0, 0] const int32 N = ctx->num_inputs() - 1; const TensorShape inp0_shape = ctx->InputShape(1); - xla::Literal inp0_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal)); - const int64 dims = inp0_shape.num_elements(); + std::vector inp0_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims)); + const int64 inp0_rank = inp0_shape.num_elements(); - xla::Literal concat_dim_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = concat_dim_literal.Get({}); + int64 cdim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim)); - VLOG(1) << "ConcatOffset " << cdim << "," << dims; - int32 axis = cdim < 0 ? cdim + dims : cdim; - OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), + VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank; + int32 axis = cdim < 0 ? cdim + inp0_rank : cdim; + OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank), errors::InvalidArgument("Concat dim is out of range: ", axis, - " vs. ", dims)); + " vs. 
", inp0_rank)); int32 offset = 0; for (int i = 0; i < N; ++i) { const TensorShape inp_shape = ctx->InputShape(1 + i); - OP_REQUIRES(ctx, dims == inp_shape.num_elements(), - errors::InvalidArgument("input ", i, " should contain ", dims, - " elements, but got ", + OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(), + errors::InvalidArgument("input ", i, " should contain ", + inp0_rank, " elements, but got ", inp_shape.num_elements())); - xla::Literal inp_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal)); + std::vector inp_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims)); - Tensor out_constant(DT_INT32, TensorShape({dims})); + Tensor out_constant(DT_INT32, TensorShape({inp0_rank})); auto out_vec = out_constant.vec(); - for (int64 j = 0; j < dims; ++j) { + for (int64 j = 0; j < inp0_rank; ++j) { if (j == axis) { out_vec(j) = offset; - offset += inp_literal.Get({j}); + offset += inp_dims[j]; } else { - const int32 inp0_element = inp0_literal.Get({j}); - const int32 inp_element = inp_literal.Get({j}); - OP_REQUIRES(ctx, (inp0_element == inp_element), + const int32 inp0_element = inp0_dims[j]; + const int32 inp_element = inp_dims[j]; + OP_REQUIRES(ctx, inp0_element == inp_element, errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", inp0_element, " vs. 
", inp_element)); diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 2628ef8e2454976aeff3859fa5dc1d8e106f32e1..dff8af800229b9605bb93e0498bc5e5cf012f244 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); - if (proto_.dtype() == DT_STRING) { - LOG(WARNING) << "Not computing Const of type DT_STRING"; - ctx->SetInvalidOutput(0); - return; - } xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index c9a1be494066e4f935a1d818bc86c86333e34fae..b1046fcc0001a3eb450c82a8545e2cfdf4e43fd0 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index c68b0bfd7961892294c2931e5c4c44de534a7740..29687c7b82f92d9f336854c4575746589c63b64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e9bdb15aa0c57fd95530798f87c68e2e63e84e1d..35e0625dbb0d4c696d36cce642d6f50f1d220c45 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -33,39 +34,20 @@ class FillOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // The output of this Op is a tensor of shape 'dims_shape' with each // element set to the scalar 'dims_literal'. - const TensorShape dims_shape = ctx->InputShape(0); - const TensorShape value_shape = ctx->InputShape(1); + const TensorShape dims_shape = ctx->InputShape("dims"); + const TensorShape value_shape = ctx->InputShape("value"); OP_REQUIRES( - ctx, IsLegacyVector(dims_shape), + ctx, TensorShapeUtils::IsVector(dims_shape), errors::InvalidArgument("dims must be a vector of int32, got shape ", dims_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(value_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(value_shape), errors::InvalidArgument("value must be a scalar, got shape ", value_shape.DebugString())); - // Evaluate the 'dims' constant input, reshaping to a vector if it - // was a 'legacy' vector (secretly a scalar). 
- xla::Literal dims_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped( - 0, {dims_shape.num_elements()}, &dims_literal)); - // Convert the dims literal into a vector that we can pass to - // XlaBuilder. - std::vector broadcast; - broadcast.reserve(dims_literal.shape().dimensions(0)); - for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(dims_literal.Get({i})); - } - // Look up the value input, reshaping to a scalar if it was a - // 'legacy' scalar (secretly a vector). - xla::XlaOp data = ctx->Input(1); - if (value_shape.dims() > 0) { - CHECK_EQ(value_shape.dims(), 1); - data = xla::Reshape(data, {}); - } - // Emit the actual computation, which broadcasts the scalar to the - // desired shape. - auto result = xla::Broadcast(data, broadcast); + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims)); + auto result = xla::Broadcast(ctx->Input("value"), dims); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index d069373086a6dcd6e7901abfc63d851a731da321..42bf4b06e5da7c6f99ad32ae36131dffd124d103 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -48,9 +48,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // We require that the dimension argument is a constant, since it lets us // dispatch to a specialized custom-call function without any run-time // overhead, when compiling ahead-of-time. 
- xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = literal.Get({}); + int64 dim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES( ctx, dim < input_shape.dims(), @@ -120,6 +119,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { ", but got shape: ", input_shape.DebugString())); } + const DataType dtype = output_type(0); + xla::PrimitiveType output_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type)); + output = xla::ConvertElementType(output, output_type); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 8dfd7de591c4a3c4768dd60b41e03d294ad49397..a99b74565dab4587bee999e3d73340ff58d21f77 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6..4f980b6d14ed667bdf4756ed740894098cae5919 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 4833a9662dd12ca72b5715373b549af105625d45..f6b8534f4d7c537e5b708ee000e00cb92123584b 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -41,10 +41,8 @@ class MirrorPadOp : public XlaOpKernel { for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); - TF_ASSIGN_OR_RETURN(int64 lhs_padding, - pad_literal.GetIntegralAsS64({dimno, 0})); - TF_ASSIGN_OR_RETURN(int64 rhs_padding, - pad_literal.GetIntegralAsS64({dimno, 1})); + int64 lhs_padding = pad_literal.Get({dimno, 0}); + int64 rhs_padding = pad_literal.Get({dimno, 1}); int64 dim_size = original_shape.dimensions(dimno); // Padding amounts on each side must be no more than the size of the @@ -65,8 +63,8 @@ class MirrorPadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); @@ -81,23 +79,19 @@ class MirrorPadOp : public XlaOpKernel { TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 
1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); // Evaluate the 'padding' constant input, reshaping to a matrix. xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::XlaBuilder* b = ctx->builder(); - auto in0 = ctx->Input(0); + auto in0 = ctx->Input("input"); xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 3f5445b4821b0918f3b220ecfe2be20bccb33dc2..36ea70ac392ff18fb52d400efa886533f8335eba 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -29,40 +30,36 @@ class PadOp : public XlaOpKernel { explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); const int dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); - if (fixed_dims == 0) { + xla::XlaOp input = ctx->Input("input"); + if (dims == 0) { // Tensor is rank 0. Return it unchanged. - ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(0, input); return; } - // Evaluate the 'padding' constant input, reshaping to a matrix. 
xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::PaddingConfig config; - for (int i = 0; i < fixed_dims; ++i) { + for (int i = 0; i < dims; ++i) { auto* dim = config.add_dimensions(); - int before = pad_literal.Get({i, 0}); - int after = pad_literal.Get({i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument( "Paddings must be non-negative: ", before, " ", after)); @@ -73,12 +70,13 @@ class PadOp : public XlaOpKernel { // PadV2 added a "constant_values" input that indicates the pad value. xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), - errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")), + errors::InvalidArgument("constant_values must be a scalar.")); + ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(input, zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 415ce9b77ffeac8a6a5f3c23537afb16c1d3567c..8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 107fa62967a55dffcfff8728b65338564e5202d2..132160de707911f26389034e16236985bb18e6ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -113,11 +113,21 @@ class MeanOp : public XlaReductionOp { xla::Add(scalar_lhs, scalar_rhs); } - xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) override { - auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), - num_elements_reduced); + xla::XlaOp BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce) override { + if (dimensions_to_reduce.empty()) { + return reduce_output; + } + auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + for (int i = 1; i < dimensions_to_reduce.size(); i++) { + auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + divisor = xla::Mul(divisor, size); + } + xla::PrimitiveType type; + TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); + divisor = xla::ConvertElementType(divisor, type); return reduce_output / divisor; } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 466e79828d111ee7cadcf713703e8f252c63e62c..8f1667df5b72e9ecf97e5771670ef209dee287a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ 
b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -48,13 +48,14 @@ class XlaReductionOp : public XlaOpKernel { const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired - // computation should be added to 'builder'. Argument 'reduce_output' is the - // output of the reduction. 'num_elements_reduced' is the number of elements - // that contributed to the reduction. Returns the transformed reduction - // output, Defaults to returning 'reduce_output' unchanged. - virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced); + // computation should be added to 'builder'. Argument 'input' is the original + // input of the reduction; 'reduce_output' is the output of the reduction. + // Returns the transformed reduction output, Defaults to returning + // 'reduce_output' unchanged. + virtual xla::XlaOp BuildFinalizer( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 118f2798d559f43acb7f6394a7337426164325ef..e96cabbb853be744dbba7f19fbbd227bb52ebb06 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -37,9 +37,10 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, // Unless BuildFinalizer is overridden the reduction has no // finalizer. 
-xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) { +xla::XlaOp XlaReductionOp::BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& /*input*/, + const xla::XlaOp& reduce_output, + const std::vector& /*dimensions_to_reduce*/) { return reduce_output; } @@ -71,7 +72,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; - int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { int64 index = axes[i]; OP_REQUIRES(ctx, @@ -82,7 +82,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { index = (index + data_shape.dims()) % data_shape.dims(); bitmap[index] = true; xla_axes.push_back(index); - num_elements_reduced *= data_shape.dim_size(index); } std::vector final_shape; @@ -119,7 +118,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto finalized = BuildFinalizer(b, data, deconverted, xla_axes); auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a8f33c8f39e47d7bd1f59413be880c51d273cf1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -0,0 +1,587 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +using xla::XlaOp; + +// Calculates the bilinear weight tensor, given basis ratio (px, py) of the +// sampling position: +// W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] +// 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2]. +// +// The returned tensor has dimensions [batch, dim_0, ... 
dim_n, 4]. +XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, + const TensorShape warp_shape, + xla::PrimitiveType xla_type) { + auto first_term = xla::ConstantR2( + ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}}); + first_term = xla::ConvertElementType(first_term, xla_type); + + auto warp_dims = warp_shape.dim_sizes(); + std::vector broadcast_dims(warp_dims.begin(), warp_dims.end() - 1); + broadcast_dims.push_back(4); + broadcast_dims.push_back(2); + + const int64 broadcast_dims_size = broadcast_dims.size(); + + std::vector last_two_dims_indices = {(broadcast_dims_size - 2), + (broadcast_dims_size - 1)}; + + xla::Shape broadcast_shape = + xla::ShapeUtil::MakeShape(xla_type, broadcast_dims); + + auto broadcast_first_term = + xla::BroadcastInDim(first_term, broadcast_shape, last_two_dims_indices); + + // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n, + // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the + // [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last + // dimension. + std::vector ratio_broadcast_indices(broadcast_dims.size()); + std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0); + ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2); + + auto broadcast_ratio = + xla::BroadcastInDim(ratio, broadcast_shape, ratio_broadcast_indices); + + auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio; + + // Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (-px, -py)], need + // to flip the sign of the x component of the second and the third term. 
+ auto sign_change = xla::ConstantR2( + ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}}); + sign_change = xla::ConvertElementType(sign_change, xla_type); + + auto broadcast_sign_change = + xla::BroadcastInDim(sign_change, broadcast_shape, last_two_dims_indices); + + auto flipped = first_term_subtract_weights * broadcast_sign_change; + + // Build up the final bilinear weight tensor by multiply reduction, which + // gives: + // [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] + // for each 4 neighboring pixels where px and py are the weight of the target + // pixel we are sampling from. + return xla::Reduce( + flipped, xla::One(ctx->builder(), xla_type), + xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()), + {broadcast_dims_size - 1}); +} + +// Concatenates the batch indices to the (x, y) coordinate indices. +// This is done by first creating an Iota tensor that represents the current +// batch it is in, then concatenating it with the given (coordinate) indices. +// +// The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where +// the last dimension of size 3 in turn is [batch_number, x, y]. +// The [batch_number, x, y] dimension is needed because the indices +// [x,y] alone cannot allow the xla::Gather operation to gather from the input +// data, which is of dimension [batch, height(y), width(x), channel] with +// 'batch' being the first dimension. +XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, + const TensorShape& warp_shape) { + // We need to create an iota tensor with the same batch dimension. + std::vector dimensions; + for (auto dim : warp_shape) { + dimensions.push_back(dim.size); + } + // Except the last dimension, which is of size 1. 
+ dimensions.back() = 1; + + auto batch_indices = + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + /*iota_dimension=*/0); + + return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); +} + +// Gathers the 2x2 neighbors of the input starting_indices, and return a +// tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels]. +// 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last +// dimension of size 3 is (batch_no, x, y). +XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices, + int64 data_channels, int warp_dims) { + xla::GatherDimensionNumbers gather_dim_numbers; + const int64 neighbor_data_dimensions = warp_dims + 2; + // Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2, + // data_channels], the offset dimensions for Gather is the last 3 dimensions. + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3); + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2); + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1); + // The last dimension of 'gather_indices' is the starting indices for gather. + gather_dim_numbers.set_index_vector_dim(warp_dims - 1); + gather_dim_numbers.add_collapsed_slice_dims(0); + gather_dim_numbers.add_start_index_map(0); + // Since input is of dimension [batch, height(y), width(x), channel], and warp + // is of dimension [batch, x, y], the ordering of x, y here needs to be + // swapped when gathering. + gather_dim_numbers.add_start_index_map(2); + gather_dim_numbers.add_start_index_map(1); + // Data dimensions are [batch, x, y, channel]. + // Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels]. + auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers, + /*slice_sizes=*/{1, 2, 2, data_channels}); + // Collapse the ...,2,2,... dimensions into ...,4,... 
+ return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims}); +} + +// Scatter-adds the 'updates' tensor, of dimension [batch, dim_0, ...dim_n, 2, 2, +// data_channels], into 'grad_data' based on 'indices', and returns the result. +// This function can also be seen as the inverse of 'Gather2by2Neighbors'. +XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, + XlaOp updates, int64 warp_dims, + xla::PrimitiveType xla_type) { + xla::ScatterDimensionNumbers scatter_dim_numbers; + const int64 neighbor_data_dimensions = warp_dims + 2; + // Since the Scatter updates dimensions are [batch, dim_0, ... dim_n, 2, 2, + // data_channels], the update window dimensions are the last 3 dimensions. + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3); + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2); + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1); + scatter_dim_numbers.set_index_vector_dim(warp_dims - 1); + + scatter_dim_numbers.add_inserted_window_dims(0); + scatter_dim_numbers.add_scatter_dims_to_operand_dims(0); + // Since input is of dimension [batch, height(y), width(x), channel], and warp + // is of dimension [batch, x, y], the ordering of x, y here needs to be + // swapped when scattering. + scatter_dim_numbers.add_scatter_dims_to_operand_dims(2); + scatter_dim_numbers.add_scatter_dims_to_operand_dims(1); + + return xla::Scatter(grad_data, indices, updates, + xla::CreateScalarAddComputation(xla_type, ctx->builder()), + scatter_dim_numbers); +} + +// Build computation for the backprop into input 'data'. 
+// Where input: +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// +// Output: +// scatter-add to each 2x2 grad_data neighbor: +// grad_data[fx, fy, chan] += output_grad * dx * dy +// grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy +// grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) +// grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) +// where (dx, dy) is (1 - ratio). +XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, + XlaOp gather_indices, xla::PrimitiveType warp_type, + TensorShape warp_shape, int64 data_channels, + xla::Shape data_shape) { + // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. + auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); + + auto warp_dims = warp_shape.dim_sizes(); + std::vector warp_dims_without_last_dims(warp_dims.begin(), + warp_dims.end() - 1); + + std::vector reshaped_weights_dims = warp_dims_without_last_dims; + // Reshape the last dimension of size 4 to two dimensions [2, 2]. + reshaped_weights_dims.push_back(2); + reshaped_weights_dims.push_back(2); + std::vector reshape_dims(warp_shape.dims()); + std::iota(reshape_dims.begin(), reshape_dims.end(), 0); + // The dimension is [batch, dim_0,..., dim_n, 2, 2]. + auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims, + /*new_sizes=*/reshaped_weights_dims); + + std::vector weights_with_channels_dims = reshaped_weights_dims; + weights_with_channels_dims.push_back(data_channels); + auto weights_with_channels_shape = + xla::ShapeUtil::MakeShape(warp_type, weights_with_channels_dims); + std::vector reshaped_weights_indices(reshaped_weights_dims.size()); + std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), + 0); + + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. 
+ auto broadcast_reshaped_weights = xla::BroadcastInDim( + reshaped_weights, weights_with_channels_shape, reshaped_weights_indices); + + std::vector grad_output_indices(warp_dims_without_last_dims.size()); + std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0); + grad_output_indices.push_back(weights_with_channels_dims.size() - 1); + XlaOp broadcast_grad_output = xla::BroadcastInDim( + grad_output, weights_with_channels_shape, grad_output_indices); + + auto grad_output_multiply_weights = + broadcast_grad_output * broadcast_reshaped_weights; + + auto grad_data = xla::ConstantLiteral( + ctx->builder(), xla::Literal::CreateFromShape(data_shape)); + + return ScatterToGradData(ctx, grad_data, gather_indices, + grad_output_multiply_weights, warp_shape.dims(), + warp_type); +} + +// Build computation for the backprop into input 'warp'. +// Where input: +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data is of dimension [batch, x, y, channel] +// +// Output (simplified by ignoring the batch dimensions): +// Since the forward path has: +// output = dot(weights * neighbors) +// The backprop into warp will therefore be: +// grad_warp = output_grad * d_output / d_warp +// = output_grad * (d_weights / d_warp * neighbors + d_neighbors / +// d_warp * weight) +// Where: +// d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py] +// d_weights / d_warp_y = [-(1 - px), -px, (1-px), px] +// and +// d_neighbors / d_warp_x = 0 +// +// Therefore: +// grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) +// grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) +// +// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// bottom right corner in a 2x2 neighborhood. 
+XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, + XlaOp gather_indices, XlaOp data, + TensorShape warp_shape, int64 data_channels, + xla::PrimitiveType data_type) { + auto warp_dims = warp_shape.dim_sizes(); + std::vector warp_dims_without_last_dims(warp_dims.begin(), + warp_dims.end() - 1); + + std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; + neighbor_broadcast_dims.push_back(4); + + // With dimension [batch, dim_0, ...dim_n, 4] + auto neighbor_broadcast_shape = + xla::ShapeUtil::MakeShape(data_type, neighbor_broadcast_dims); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = Gather2by2Neighbors( + ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + + const int64 last_warp_dim = warp_shape.dims() - 1; + + // Since we will be creating the dot product of: + // lhs: [batch, dim_0, ...dim_n, 4] + // and + // rhs: [batch, dim_0, ...dim_n, 4, data_channels] + // we choose the last dimension of lhs and the second last dimension of rhs, + // with size 4, as the contracting dimension. 
+ xla::DotDimensionNumbers dot_dims; + for (int i = 0; i < warp_shape.dims() - 1; ++i) { + dot_dims.add_lhs_batch_dimensions(i); + dot_dims.add_rhs_batch_dimensions(i); + } + dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); + dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + + // img_cxcy - img_fxcy + auto bottom_right_minus_bottom_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {0, 0, -1, 1}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_cxfy - img_fxfy + auto top_right_minus_top_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, 1, 0, 0}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_cxcy - img_cxfy + auto bottom_right_minus_top_right = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {0, -1, 0, 1}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_fxcy - img_fxfy + auto bottom_left_minus_top_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, 0, 1, 0}), data_type), + neighbor_broadcast_shape, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // Slice out x and y. + auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1, + /*stride=*/1, /*dimno=*/last_warp_dim); + auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2, + /*stride=*/1, /*dimno=*/last_warp_dim); + + // Build 1 - y and 1 - x. 
+ auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y; + auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x; + + auto x_before_reduce = + grad_output * weight_y * bottom_right_minus_bottom_left + + one_minus_y * top_right_minus_top_left; + + std::vector reshaped_sizes = warp_dims_without_last_dims; + reshaped_sizes.push_back(1); + + std::vector reshaped_dims(warp_dims_without_last_dims.size()); + std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0); + + // Reduce-add along the channel dimension. + auto x_result = + xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type), + xla::CreateScalarAddComputation(data_type, ctx->builder()), + {last_warp_dim}); + // Reshape before concatenating with y values. + XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes); + + auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right + + one_minus_x * bottom_left_minus_top_left; + // Reduce-add along the channel dimension. + auto y_result = + xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type), + + xla::CreateScalarAddComputation(data_type, ctx->builder()), + {last_warp_dim}); + XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes); + + return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y}, + last_warp_dim); +} + +class ResamplerOp : public XlaOpKernel { + public: + explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape data_shape = ctx->InputShape("data"); + OP_REQUIRES(ctx, data_shape.dims() == 4, + errors::InvalidArgument("data must be 4-dimensional", + data_shape.DebugString())); + const int64 data_channels = data_shape.dim_size(3); + xla::PrimitiveType data_type = ctx->input_xla_type(0); + + TensorShape warp_shape = ctx->InputShape("warp"); + OP_REQUIRES(ctx, warp_shape.dims() >= 2, + errors::InvalidArgument("warp must be at least 2-dimensional", + 
warp_shape.DebugString())); + for (int size : warp_shape.dim_sizes()) { + OP_REQUIRES(ctx, size > 0, + errors::InvalidArgument("warp sizes must be positive, got [", + size, "]")); + } + const int64 last_warp_dim = warp_shape.dims() - 1; + // Last dimension of warp shape must be of size 2. + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, + errors::InvalidArgument( + "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); + + XlaOp data = ctx->Input("data"); + XlaOp warp = ctx->Input("warp"); + + // Find the coordinates of the top left corner for the 2x2 region to be + // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the + // last dimension of size 2 in turn is [x, y]. + XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + + auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = Gather2by2Neighbors( + ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + + // Dimensions are [batch, dim_0, ... dim_n, 2]. + XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type); + + // Obtain the bilinear blending weights, the dimension is [batch, dim_0, + // ...dim_n, 4]. + auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type); + + // Since we will be creating the dot product of: + // lhs: [batch, dim_0, ...dim_n, 4] + // and + // rhs: [batch, dim_0, ...dim_n, 4, data_channels] + // we choose the last dimension of lhs and the second last dimension of rhs, + // with size 4, as the contracting dimension. 
+ xla::DotDimensionNumbers dot_dims; + for (int i = 0; i < warp_shape.dims() - 1; ++i) { + dot_dims.add_lhs_batch_dimensions(i); + dot_dims.add_rhs_batch_dimensions(i); + } + dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); + dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + + // The dimension is [batch, dim_0, ...dim_n, data_channels]. + auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims, + /*precision_config=*/nullptr); + + // Handle out of boundary cases by constructing a predicate mask array based + // on the in-bound condition, and output 0 for the blended pixel value if + // out-bound. The dimension is the same as top_left: [batch, dim_0, + // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate. + + auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp)); + + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dim_size(2) - 1), + /*height=*/static_cast(data_shape.dim_size(1) - 1)}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. + auto is_in_bound = xla::Reduce( + is_in_bound_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, + ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which + // is the dimension of the result: + // [batch, dim_0, ...dim_n, data_channels]. 
+ auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(data_channels); + xla::Shape broadcasted_shape = + xla::ShapeUtil::MakeShape(xla::PrimitiveType::PRED, result_dims); + + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, broadcasted_shape, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims); + auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros); + + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("Resampler"), ResamplerOp); + +class ResamplerGradOp : public XlaOpKernel { + public: + explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + } + + // TODO(b/112295522): note that sampling from image boundary is not currently + // being handled properly. + void Compile(XlaOpKernelContext* ctx) override { + TensorShape data_shape_tf = ctx->InputShape("data"); + OP_REQUIRES(ctx, data_shape_tf.dims() == 4, + errors::InvalidArgument("data must be 4-dimensional", + data_shape_tf.DebugString())); + const int64 data_channels = data_shape_tf.dim_size(3); + xla::PrimitiveType data_type = ctx->input_xla_type(0); + + TensorShape warp_shape = ctx->InputShape("warp"); + OP_REQUIRES(ctx, warp_shape.dims() >= 2, + errors::InvalidArgument("warp must be at least 2-dimensional", + warp_shape.DebugString())); + for (int size : warp_shape.dim_sizes()) { + OP_REQUIRES(ctx, size > 0, + errors::InvalidArgument("warp sizes must be positive, got [", + size, "]")); + } + // Last dimension of warp shape must be of size 2. 
+ OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + errors::InvalidArgument( + "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); + + TensorShape output_grad_shape = ctx->InputShape("grad_output"); + OP_REQUIRES( + ctx, output_grad_shape.dims() >= 2, + errors::InvalidArgument("output_grad must be at least 2-dimensional", + output_grad_shape.DebugString())); + + // Dimensions are [batch, x, y, channel]. + XlaOp data = ctx->Input("data"); + xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf); + + // Dimensions are [batch, dim_0, ...dim_n, 2]. + XlaOp warp = ctx->Input("warp"); + // Dimensions are [batch, dim_0, ...dim_n, channel]. + XlaOp grad_output = ctx->Input("grad_output"); + + // Find the top left corner coordinate for the region to be sampled from. + // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension + // of size 2 in turn is [x, y]. + XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + + // Dimensions are [batch, dim_0, ... dim_n, 2] + XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); + + // Indices for gathering neighboring pixels. 
+ auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); + + auto grad_data = + CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, + warp_shape, data_channels, data_shape); + + auto grad_warp = + CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, + warp_shape, data_channels, data_type); + + ctx->SetOutput(0, grad_data); + ctx->SetOutput(1, grad_warp); + } +}; + +REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 47a4eac20669c225a653bd3f1f00eeafd0845a42..fa1b6b91710f5507f41f3f69b0715398ae879aaf 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -36,7 +37,7 @@ class ReshapeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); const TensorShape sizes_shape = ctx->InputShape(1); // Preliminary validation of sizes. 
- OP_REQUIRES(ctx, IsLegacyVector(sizes_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape), errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e172c649325adb6f7761ce0be141f21e8d545bc1..6970dd0a00641c9f88571561501fb3454fb3eab3 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -46,61 +47,8 @@ class RetvalOp : public XlaOpKernel { // compilation. 
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::XlaOp input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - DataType input_type = ctx->input_type(0); - XlaContext& tc = XlaContext::Get(ctx); - - if (input_type == DT_RESOURCE) { - XlaResource* resource; - OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - ctx->SetStatus(tc.AddResourceRetval(index_, resource)); - return; - } - - auto is_constant = ctx->builder()->IsConstant(input); - if (!is_constant.ok()) { - ctx->SetStatus(is_constant.status()); - return; - } - - if (tc.resolve_compile_time_constants() && - (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); - OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); - } else { - TensorShape shape = ctx->InputShape(0); - ctx->SetStatus(is_constant.status()); - TensorShape representation_shape; - if (tc.is_entry_computation()) { - xla::StatusOr shape_or_status = - tc.RepresentationShape(shape, ctx->input_type(0)); - if (!shape_or_status.ok()) { - ctx->SetStatus(shape_or_status.status()); - return; - } else { - representation_shape = shape_or_status.ValueOrDie(); - } - } else { - representation_shape = shape; - } - - xla::XlaOp output = input; - if (tc.is_entry_computation()) { - output = xla::Reshape(input, representation_shape.dim_sizes()); - } else { - // The core from which a return value is returned depends on the - // device assignment of the input to the retval. Since we can't change - // the device assignment of "input" at this point, we must always - // introduce an operator here, even if the shape does not change. - // TODO(b/76097077): propagate device assignments onto arguments and - // return values of functions, and then reshape unconditionally. 
- output = - xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); - } - tc.AddRetval(index_, dtype_, shape, output); - } + XlaContext& xla_context = XlaContext::Get(ctx); + xla_context.SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 56b80cb4a299c07157a166208b96ad369075aa83..2ceadaf79c5cef35ad50aa84a0d66a46527a6458 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel { } // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); - std::vector revdims(x_shape.dims()); - std::copy(lax.data().begin(), lax.data().end(), - revdims.begin()); - std::vector dimensions; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax)); + std::vector dimensions; for (int d = 0; d < x_shape.dims(); ++d) { - if (revdims[d]) { + if (lax.Get({d})) { dimensions.push_back(d); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 7ff3e9163811434e8d621795c22bf8304ba7a1ed..d7b38e86cc985d608116488f9e76756a8e904f9c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 379f4aeb0fc7bbfff59696726f5af231b1294c49..b1fa2915d59e4e5e2f2523e20e9a37898d087117 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -30,31 +30,6 @@ limitations under the License. namespace tensorflow { namespace { -template -Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = literal.Get({}); - return Status::OK(); -} - -Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - switch (literal.shape().element_type()) { - case xla::S32: - *value = literal.Get({}); - break; - case xla::S64: - *value = literal.Get({}); - break; - default: - return errors::InvalidArgument("Invalid argument type for argument", - index); - } - return Status::OK(); -} - // The type-specific part of the implementation of Range. 
template xla::StatusOr CreateRangeTensor( @@ -98,13 +73,13 @@ class RangeOp : public XlaOpKernel { const TensorShape start_in_shape = ctx->InputShape(0); const TensorShape limit_in_shape = ctx->InputShape(1); const TensorShape delta_in_shape = ctx->InputShape(2); - OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape), errors::InvalidArgument("limit must be a scalar, not shape ", limit_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape), errors::InvalidArgument("delta must be a scalar, not shape ", delta_in_shape.DebugString())); xla::Literal start, limit, delta; @@ -147,9 +122,9 @@ class LinSpaceOp : public XlaOpKernel { explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape start_in_shape = ctx->InputShape(0); - const TensorShape stop_in_shape = ctx->InputShape(1); - const TensorShape num_in_shape = ctx->InputShape(2); + const TensorShape start_in_shape = ctx->InputShape("start"); + const TensorShape stop_in_shape = ctx->InputShape("stop"); + const TensorShape num_in_shape = ctx->InputShape("num"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); @@ -163,16 +138,20 @@ class LinSpaceOp : public XlaOpKernel { DataType type = ctx->input_type(0); int64 num; - OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num)); OP_REQUIRES(ctx, num > 0, errors::InvalidArgument("Requires num > 0: ", num)); Tensor out_constant(type, TensorShape({num})); + xla::Literal start_literal; + 
OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal)); + xla::Literal stop_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal)); + switch (type) { case DT_FLOAT: { - float start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + float start = start_literal.GetFirstElement(); + float stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; @@ -185,9 +164,8 @@ class LinSpaceOp : public XlaOpKernel { break; } case DT_DOUBLE: { - double start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + double start = start_literal.GetFirstElement(); + double stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 37b026aeb058b464acd74264766f187b787914aa..12830816ec16c9797f0fe4d8f3f13f5a8176161d 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { @@ -108,21 +109,16 @@ class ExpandDimsOp : public XlaOpKernel { explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape dim_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape dim_shape = ctx->InputShape("dim"); - // TODO(phawkins): the standard implementation of ExpandDimsOp seems to - // accept legacy scalars, even when they should be forbidden by the graphdef - // version. - OP_REQUIRES(ctx, dim_shape.num_elements() == 1, + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims)); + OP_REQUIRES(ctx, dims.size() == 1, errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); - - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); - - int dim = literal.data()[0]; + int dim = dims[0]; OP_REQUIRES(ctx, (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), @@ -148,7 +144,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 34980ead81815c2818a259e096148fcce9c9a3b1..88da64e5a217a0c026106f03cb26958f6738446c 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" @@ -42,8 +43,8 @@ class SliceOp : public XlaOpKernel { OP_REQUIRES( ctx, - IsLegacyVector(begin_tensor_shape) && - IsLegacyVector(size_tensor_shape) && + TensorShapeUtils::IsVector(begin_tensor_shape) && + TensorShapeUtils::IsVector(size_tensor_shape) && begin_tensor_shape.num_elements() == input_shape.dims() && size_tensor_shape.num_elements() == input_shape.dims(), errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 230a343f7966f19cda44991a747287ba675fca4c..7a0e240400b344ab25743997ce3baad81bd5f476 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -35,26 +35,16 @@ class SplitOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); - const TensorShape index_shape = ctx->InputShape(0); + const TensorShape split_dim_shape = ctx->InputShape("split_dim"); const TensorShape input_shape = ctx->InputShape(1); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); - - int32 split_dim_orig; - if (index_shape.dims() == 0) { - split_dim_orig = literal_index.Get({}); - } else { - OP_REQUIRES( - ctx, index_shape.dims() == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - "scalar or a vector with 1 element")); - OP_REQUIRES( - ctx, index_shape.dim_size(0) == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - 
"scalar or a vector with 1 element")); - split_dim_orig = literal_index.Get({0}); - } + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(split_dim_shape), + errors::InvalidArgument("split_dim must be a scalar but has rank ", + split_dim_shape.dims())); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig)); + int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() : split_dim_orig; OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), @@ -138,7 +128,6 @@ class SplitVOp : public XlaOpKernel { // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; - std::vector split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); OP_REQUIRES(ctx, split_size_shape.dims() == 1 && @@ -150,12 +139,11 @@ class SplitVOp : public XlaOpKernel { split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); // Get the dimension of this split. - xla::Literal split_size_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); + std::vector split_sizes; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes)); for (int i = 0; i < num_split; ++i) { - int slice_size; - slice_size = split_size_literal.Get({i}); + int64 slice_size = split_sizes[i]; if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, @@ -164,7 +152,6 @@ class SplitVOp : public XlaOpKernel { i)); neg_one_dim = i; } else { - split_sizes_vec[i] = slice_size; total_split_size += slice_size; } } @@ -183,7 +170,7 @@ class SplitVOp : public XlaOpKernel { total_split_size)); if (neg_one_dim >= 0) { - split_sizes_vec[neg_one_dim] = + split_sizes[neg_one_dim] = input_shape.dim_size(split_dim) - total_split_size; } @@ -195,7 +182,7 @@ class SplitVOp : public XlaOpKernel { std::vector strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); - int slice_size = split_sizes_vec[i]; + int slice_size = 
split_sizes[i]; output_shape.set_dim(split_dim, slice_size); // Slice out the ith split from the split dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d79cdad9fa2dabe1e236741955499b845064148f..7b96b43ad834c28aa0283c5ef4ac516618ca5134 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp); +REGISTER_XLA_OP( + Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(), + StackOp); class StackPushOp : public XlaOpKernel { public: @@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); }; -REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); +REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp); class StackPopOp : public XlaOpKernel { public: @@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); }; -REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); +REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp); class StackCloseOp : public XlaOpKernel { public: @@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); }; -REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); +REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 7b2cd5a5b08d80284d172c2ed5d6be4c355e76e0..e1c764f3d5c28cf0d812519e4a16786e1f2d3a3a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" @@ -44,7 +45,7 @@ class TileOp : public XlaOpKernel { const TensorShape multiples_shape = ctx->InputShape("multiples"); OP_REQUIRES( - ctx, IsLegacyVector(multiples_shape), + ctx, TensorShapeUtils::IsVector(multiples_shape), errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", multiples_shape.DebugString())); OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(), diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 48a211942d7c4405bf68189e641ee184db36b0ba..c9b324a243e4cc3ec64daa3ca0d285336a0d0154 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel { : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape perm_tensor_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("x"); + const TensorShape perm_tensor_shape = ctx->InputShape("perm"); // Preliminary validation of sizes. OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), @@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel { ". 
But input(1) is a vector of size ", perm_tensor_shape.num_elements())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); - - std::vector perm(dims); - std::copy(literal.data().begin(), literal.data().end(), - perm.begin()); + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm)); std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { - const int32 d = perm[i]; + const int64 d = perm[i]; OP_REQUIRES( ctx, 0 <= d && d < dims, errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); @@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel { xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - transposed = ctx->Input(0); + transposed = ctx->Input("x"); } else { - transposed = xla::Transpose(ctx->Input(0), transposed_order); + transposed = xla::Transpose(ctx->Input("x"), transposed_order); } // Conjugate the transposed result if this is ConjugateTransposeOp. diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0bdfc05726105e2d18362a691cbe2aab00bf77f3..a0ea6422d732b00fc1b8cf855d9c9ad603b87c82 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -80,24 +80,8 @@ XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); XLAJIT_MAKE_UNARY(Neg, -x); -// Implements Banker's rounding: numbers that are equidistant between two -// integers are rounded towards even. 
-xla::XlaOp RoundToEven(xla::XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); - - auto round_val = xla::Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); -} - -XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x)); XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 20103ec3ae00b57723e05326dbbb1b0f6e1a671a..67d08290033361f16dfff42b06af9b253e84963a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor) { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal)); + return literal.Clone(); +} + Status HostTensorToMutableBorrowingLiteral( Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { xla::Shape xla_shape; diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 1db7470ee2a839099454b772d4833492e033bc92..a153dddee6127ff9c0858220f2d8a735ab3f0e19 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,11 @@ namespace tensorflow { // 'host_tensor'. 
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); + +// Returns a Literal with the contents of 'host_tensor', backed by its own +// storage (i.e., not reusing 'host_tensor's buffers.) +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor); + // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer // owned by 'host_tensor', but is mutable via the xla::Literal methods. Status HostTensorToMutableBorrowingLiteral( diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 8b559c87506a6b519e2ad1d1bf22ab30c0ff161d..c9f486edc8d30954619db0967c988fe8e26938de 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package( default_visibility = [ "//learning/deepmind/public/wavenet/python:__subpackages__", + "//learning/deepmind/research/alphastar:__subpackages__", "//learning/tfx:__subpackages__", "//tensorflow:internal", ], diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index cb7843850c352eee2e55baf52a0c4445dc861d7b..ddb284966eeb97cc7c9d3ed77fb313e567975e59 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto( "XLACompilationDevice::MakeTensorFromProto should not be called"); } -XlaExpression::XlaExpression() = default; - -void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } - -void XlaExpression::set_constant_value(Tensor value) { - has_constant_value_ = true; - constant_value_ = std::move(value); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index a6e78825334fec748be5fee80669649df699d2fb..de6a3356e05d8ab45c269d7c6c653853d2c63a79 100644 
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,9 +18,6 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" @@ -38,8 +35,8 @@ class XlaCompilationAllocator; // This is a 'dummy' TensorFlow device that is only used to execute a // subgraph of XLA compilation Ops to construct a compiled version // of the subgraph's computation. It has a 'dummy' allocator that -// backs each Tensor with metadata indicating the computation the -// Tensor represents. +// backs each Tensor with an XlaExpression. The shape of the Tensor +// matches the shape of XlaExpression. // // We deliberately don't register a device factory because we *never* // want placement to put Ops on a compilation device. The device is created @@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// A XlaExpression wraps an XLA computation. Each Tensor on an -// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the XlaOp. Each -// expression is either a constant, or a function of previously-compiled -// expressions. -class XlaExpression { - public: - XlaExpression(); - - // handle() stores the XLA handle of the computation that the - // expression represents. 
- void set_handle(const xla::XlaOp& h); - const xla::XlaOp& handle() const { return handle_; } - - void set_constant_value(Tensor value); - bool has_constant_value() const { return has_constant_value_; } - const Tensor& constant_value() const { return constant_value_; } - - void set_resource(XlaResource* resource) { resource_ = resource; } - XlaResource* resource() const { return resource_; } - - private: - // The XLA handle of the expression's computation. - xla::XlaOp handle_; - - // If this expression is a constant with a known value, 'constant_value' is a - // host-memory Tensor containing the value. Used to avoid invoking XLA for - // expressions that are trivially constant. - bool has_constant_value_ = false; - Tensor constant_value_; - - XlaResource* resource_ = nullptr; // Not owned. -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 425e769346ffcbc548495d93cb7adc779f860110..66206909a92fddbac4e77e5d2d8164fcbb46f317 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -287,11 +287,6 @@ class XlaCompiledCpuFunction { // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. - // - // For now we need to keep around the args_ array because there is code that - // depends on args() returning a void**. However, in the future we may remove - // args_ in favor of using buffer_table_ as the sole storage for the - // arguments. const int32* const arg_index_table_; // The number of incoming arguments. 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index e177a5f07f5607a0f9de75e6a999ee492cd9db4f..8036bc684401ff31c07ac381098e05fb8b7ee76a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -36,10 +36,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -48,7 +51,7 @@ namespace { // Checks that arguments `args` match types `types`. Status CheckSignature(const DataTypeVector& types, - const std::vector& args) { + absl::Span args) { if (args.size() != types.size()) { return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); @@ -63,6 +66,262 @@ Status CheckSignature(const DataTypeVector& types, return Status::OK(); } +// Uses the _Arg and _Retval nodes in the graph to determine a core assignment +// for each argument and return value. 
+xla::StatusOr, std::map>> +ComputeArgAndRetvalCores(const Graph& graph) { + auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + std::map arg_cores; + std::map retval_cores; + for (const Node* n : graph.nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Arg index"; + arg_cores[index] = core; + } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Retval index"; + TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n)); + retval_cores[index] = core; + } + } + return std::make_pair(std::move(arg_cores), std::move(retval_cores)); +} + +Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, FunctionLibraryRuntime* flib, + int64 step_id) { + // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the + // resource manager takes ownership via Create, and unrefs via Cleanup. We + // explicitly add a reference to ensure the refcount at entry is maintained at + // all exit points; Create and Cleanup are always called in this function. + // + // The Executor requires us to use ScopedStepContainer. We wrap it in a + // unique_ptr so we can capture the cleanup status in the end. 
+ xla_context->Ref(); + Status status; + auto step_container = absl::make_unique( + step_id, [&status, device](const string& name) { + status = device->resource_manager()->Cleanup(name); + }); + TF_RETURN_IF_ERROR(device->resource_manager()->Create( + step_container->name(), XlaContext::kXlaContextResourceName, + xla_context)); + + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); + // Explicitly clean up the step container, to capture the cleanup status. + step_container.reset(); + return Status::OK(); +} + +// Builds the XLA computation. +// - `args` is the list of input arguments +// - `retvals` is the list of retvals produced by _Retval operators, in index +// order. +// - `args_core` and `retval_cores` are mapping from arg/return indices to core +// assignments. +// - If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// - Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// - `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a ResourceUpdate, whose `index` is the index of a +// resource variable argument to the computation to be updated, and `type` is +// the type of the final output. 
+Status BuildComputation( + const std::vector& args, + const std::vector& retvals, + const std::map& arg_cores, const std::map& retval_cores, + const std::vector>& resources, + std::unique_ptr token_output, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, + std::vector* resource_updates, + xla::Shape* output_shape) { + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); + + // Builds a no-op XLA computation. We need to set the sharding of outputs, but + // cannot change the sharding of the existing output op. To do this, we build + // a new identity op to which shardings can be applied. + auto identity_op = [builder](xla::XlaOp op) { + return xla::GetTupleElement(xla::Tuple(builder, {op}), 0); + }; + + std::vector elems; + elems.reserve(retvals.size()); + + // Keeps track of which retvals have layout to update. The first element is + // the output index, second element is the new layout. 
+ std::vector> retval_to_update_layout; + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + const XlaExpression& retval = retvals[i]; + output.type = retval.dtype(); + switch (retval.kind()) { + case XlaExpression::Kind::kConstant: + output.is_constant = true; + output.constant_value = retval.constant_value(); + output.shape = output.constant_value.shape(); + break; + + case XlaExpression::Kind::kXlaOp: { + output.is_constant = false; + TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); + xla::XlaOp value = retval.handle(); + auto it = retval_cores.find(i); + xla::XlaScopedShardingAssignment assign_sharding( + builder, it == retval_cores.end() + ? absl::optional() + : xla::sharding_builder::AssignDevice(it->second)); + if (shape_representation_fn) { + // If there is a shape representation function, reshape the output + // tensor to the shape given by the representation shape function. + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( + output.shape, output.type)); + value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + } else if (it != retval_cores.end()) { + // Apply the sharding to the output, if there is a core assignment. + value = identity_op(value); + } + + elems.push_back(value); + break; + } + + case XlaExpression::Kind::kResource: + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); + output.shape = retval.resource()->shape(); + break; + + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument( + "Invalid expression returned by computation. " + "This probably means a return value was not set."); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. 
+ std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + DCHECK_LT(resource->arg_num(), args.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + auto it = arg_cores.find(resource->arg_num()); + const int core = it == arg_cores.end() ? -1 : it->second; + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::XlaScopedShardingAssignment assign_sharding( + builder, core == -1 ? absl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Ensures the correct sharding is applied to the output. + handle = identity_op(handle); + + elems.push_back(handle); + } + } + + // If we have token output, append it as the last one. 
+ if (token_output) { + elems.push_back(*token_output); + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); + } + + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + + TF_ASSIGN_OR_RETURN(const auto& program_shape, + computation->GetProgramShape()); + *output_shape = program_shape.result(); + // Update the output layout to the layout of retval. + for (auto& update : retval_to_update_layout) { + if (!always_return_tuple && elems.size() == 1) { + *output_shape->mutable_layout() = update.second; + continue; + } + + xla::Shape* output_sub_shape = + xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); + *output_sub_shape->mutable_layout() = update.second; + } + return Status::OK(); +} + } // namespace bool XlaCompiler::Argument::operator==( @@ -83,6 +342,39 @@ bool XlaCompiler::Argument::operator==( return constant_value.tensor_data() == other.constant_value.tensor_data(); } +string XlaCompiler::Argument::HumanString() const { + string common; + if (!name.empty()) { + common = absl::StrCat(" name=", name); + } + absl::StrAppend(&common, " type=", DataTypeString(type), + " shape=", shape.DebugString()); + switch (kind) { + case kInvalid: + return "invalid"; + case kConstant: + return absl::StrCat("kind=constant", common, + " value=", constant_value.DebugString()); + case kResource: { + string output = absl::StrCat("kind=resource", common, " resource_kind=", + XlaResource::KindToString(resource_kind), + " initialized=", initialized); + if (tensor_array_size >= 0) { + absl::StrAppend(&output, " 
tensor_array_size=", tensor_array_size); + } + if (!tensor_array_gradients.empty()) { + absl::StrAppend(&output, " tensor_array_gradients=", + absl::StrJoin(tensor_array_gradients, ",")); + } + return output; + } + case kParameter: + return absl::StrCat("kind=parameter", common); + case kToken: + return absl::StrCat("token", common); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), @@ -110,8 +402,13 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) // The default shape representation function is the identity. if (!options_.shape_representation_fn) { - options_.shape_representation_fn = [](const TensorShape& shape, - DataType type) { return shape; }; + options_.shape_representation_fn = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -171,15 +468,16 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { return graph; } -Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - std::vector args, - XlaCompiler::CompilationResult* result) { +Status XlaCompiler::CompileFunction( + const XlaCompiler::CompileOptions& options, const NameAttrList& function, + absl::Span args, + XlaCompiler::CompilationResult* result) { const string function_id = Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; - auto it = cache_.find({function_id, args}); + const std::vector arg_vector(args.begin(), args.end()); + auto it = cache_.find({function_id, arg_vector}); if (it != cache_.end()) { *result = it->second; return Status::OK(); @@ -212,14 +510,16 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. 
We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kArgOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kRetOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -235,7 +535,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; - cache_[{function_id, args}] = *result; + cache_[{function_id, arg_vector}] = *result; return Status::OK(); } @@ -247,25 +547,24 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape; if (is_entry_computation) { TF_ASSIGN_OR_RETURN( - shape, options_.shape_representation_fn(arg.shape, arg.type)); + *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); } else { - shape = arg.shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); } - return TensorShapeToXLAShape(arg.type, shape, xla_shape); + return Status::OK(); } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, - options_.shape_representation_fn(arg.shape, arg.type)); - return TensorShapeToXLAShape(arg.type, representation_shape, - xla_shape); 
+ TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( + arg.shape, arg.type)); + + return Status::OK(); } case XlaResource::kTensorArray: { if (arg.tensor_array_size < 0) { @@ -314,175 +613,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } } -namespace { - -Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, - XlaCompilationDevice* device, FunctionLibraryRuntime* flib, - int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the - // resource manager takes ownership via Create, and unrefs via Cleanup. We - // explicitly add a reference to ensure the refcount at entry is maintained at - // all exit points; Create and Cleanup are always called in this function. - // - // The Executor requires us to use ScopedStepContainer. We wrap it in a - // unique_ptr so we can capture the cleanup status in the end. - xla_context->Ref(); - Status status; - auto step_container = absl::make_unique( - step_id, [&status, device](const string& name) { - status = device->resource_manager()->Cleanup(name); - }); - TF_RETURN_IF_ERROR(device->resource_manager()->Create( - step_container->name(), XlaContext::kXlaContextResourceName, - xla_context)); - - GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); - TF_RETURN_IF_ERROR(graph_compiler.Compile()); - // Explicitly clean up the step container, to capture the cleanup status. - step_container.reset(); - return Status::OK(); -} - -// Builds the XLA computation. -// `args` is the list of input arguments, `retvals` is the list of retvals -// produced by _Retval operators, in index order. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. 
-// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. -Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - std::unique_ptr token_output, - bool return_updated_values_for_all_resources, bool always_return_tuple, - xla::XlaBuilder* builder, xla::XlaComputation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (int i = 0; i < retvals.size(); ++i) { - XlaCompiler::OutputDescription& output = (*outputs)[i]; - output.type = retvals[i].type; - output.shape = retvals[i].shape; - const XlaExpression& retval = retvals[i].expression; - if (retval.has_constant_value()) { - output.is_constant = true; - output.constant_value = retval.constant_value(); - } else if (retval.resource() != nullptr) { - output.is_constant = false; - output.input_index = retval.resource()->arg_num(); - } else { - output.is_constant = false; - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - // Attach a common operator name as metadata. 
This has no semantic effect — it - // merely makes the HLO graph more readable when visualized via TensorBoard, - // since TensorBoard forms groups out of operators with similar names. - xla::OpMetadata retval_metadata; - retval_metadata.set_op_name("XLA_Retvals"); - builder->SetOpMetadata(retval_metadata); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = - modified || - !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? absl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. 
- handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); - } - } - - // If we have token output, append it as the last one. - if (token_output) { - elems.push_back(*token_output); - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. We *always* form a tuple here to ensure that - // the output value is the last thing added into the XLA computation, even - // if there is only one output value. - auto tuple = xla::Tuple(builder, elems); - if (!always_return_tuple && elems.size() == 1) { - xla::GetTupleElement(tuple, 0); - } - builder->ClearOpMetadata(); - - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - std::vector* arg_cores, std::vector* arg_expressions, + const std::map& arg_cores, + std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); - *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. 
@@ -504,7 +644,7 @@ Status XlaCompiler::BuildArguments( arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); - arg_expression.set_resource(resource); + arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } @@ -516,7 +656,7 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kConstant: - arg_expression.set_constant_value(arg.constant_value); + arg_expression = XlaExpression::Constant(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -541,26 +681,6 @@ Status XlaCompiler::BuildArguments( *input_shapes = arg_shapes; } - // Use the _Arg nodes in the graph to resolve core assignments. - for (const Node* n : graph.nodes()) { - if (absl::string_view(n->type_string()) != "_Arg") continue; - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - TF_RET_CHECK(index >= 0 && index < args.size()) - << "_Arg out of bounds: " << index << " vs " << args.size(); - TF_ASSIGN_OR_RETURN( - auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == - xla::OpSharding::Type::OpSharding_Type_MAXIMAL); - const int core = sharding.value().tile_assignment_devices(0); - if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { - (*arg_cores)[index] = core; - } - } - } - // Attach a common operator name as metadata. This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. 
@@ -576,11 +696,10 @@ Status XlaCompiler::BuildArguments( xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); for (int64 parameter : *input_mapping) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; + auto it = arg_cores.find(parameter); + const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); + xla::sharding_builder::AssignDevice(core); } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); @@ -589,7 +708,8 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -597,7 +717,8 @@ Status XlaCompiler::BuildArguments( } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -632,14 +753,14 @@ Status XlaCompiler::BuildArguments( // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. 
if (is_entry_computation) { - arg_expression.set_handle( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); + arg_expression = XlaExpression::XlaOp( + xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); } else { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } break; case XlaCompiler::Argument::kToken: { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); break; } case XlaCompiler::Argument::kConstant: @@ -653,46 +774,48 @@ Status XlaCompiler::BuildArguments( } Status XlaCompiler::CompileSingleOp( - const XlaCompiler::CompileOptions& options, string const& name, - OpKernelContext* ctx, const std::vector& args, - CompilationResult* result) { + const XlaCompiler::CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result) { // TODO(b/74182462): We implement this by creating a new dummy Graph including // _Arg nodes, and let CompileGraph walk it. This could be optimized. std::unique_ptr graph(new Graph(OpRegistry::Global())); Status status; // First create the actual node we care about computing. - Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + Node* main_node = graph->AddNode(node_def, &status); TF_RETURN_IF_ERROR(status); // Create dummy _Arg nodes. Link these to `node` and also via a control // dependency edge to the _SOURCE node. 
- for (int64 i = 0; i < ctx->num_inputs(); ++i) { + for (int64 i = 0; i < args.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); - Status status = NodeBuilder(name, "_Arg") - .ControlInput(graph->source_node()) - .Attr("T", ctx->input_dtype(i)) - .Attr("index", i) - .Finalize(graph.get(), &node); + string arg_name = absl::StrCat("_arg", i); + Status status = + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) + .ControlInput(graph->source_node()) + .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE + : args[i].type) + .Attr("index", i) + .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); graph->AddEdge(node, 0, main_node, i); } // Similarly with return values, create dummy _Retval nodes fed by `node`. - for (int64 i = 0; i < ctx->num_outputs(); ++i) { + for (int64 i = 0; i < result_types.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); - Status status = NodeBuilder(name, "_Retval") + string retval_name = absl::StrCat("_retval", i); + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) .Input(main_node, i) - .Attr("T", ctx->expected_output_dtype(i)) + .Attr("T", result_types[i]) .Attr("index", i) .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } FixupSourceAndSinkEdges(graph.get()); - return CompileGraph(options, name, std::move(graph), args, result); + return CompileGraph(options, node_def.name(), std::move(graph), args, result); } namespace { @@ -747,12 +870,38 @@ Status ValidateGraph(const Graph* graph, return Status::OK(); } +// Converts the value of any expressions whose values are known at compile-time +// to constants. 
+Status ResolveConstantExpressionsToConstants( + xla::Client* client, absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kXlaOp) { + TF_ASSIGN_OR_RETURN(absl::optional constant, + expression.ResolveConstant(client)); + if (constant.has_value()) { + expression = XlaExpression::Constant(*constant); + } + } + } + return Status::OK(); +} + +void ConvertConstantsToExpressions(xla::XlaBuilder* builder, + absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kConstant) { + expression = + XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype()); + } + } +} + } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; @@ -774,13 +923,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext( - this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, options.is_entry_computation, - &options_.shape_representation_fn); + XlaContext* context = + new XlaContext(this, &builder, options_.allow_cpu_custom_calls, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); - std::vector real_args(args); + std::vector real_args(args.begin(), args.end()); int token_input_index = -1; std::unique_ptr token_output; if (options.add_token_input_output) { @@ -792,10 +940,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, real_args.push_back(token_arg); } + std::map arg_cores; + std::map retval_cores; + TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), + ComputeArgAndRetvalCores(*graph)); + std::vector 
arg_expressions; - std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, &arg_expressions, &result->input_mapping, &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); @@ -843,28 +995,27 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); + std::vector retvals = context->retvals(); + if (options.resolve_compile_time_constants) { + TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals))); + } else { + ConvertConstantsToExpressions(&builder, absl::Span(retvals)); + } TF_RETURN_IF_ERROR(BuildComputation( - real_args, arg_cores, context->retvals(), context->resources(), - std::move(token_output), options.return_updated_values_for_all_resources, + real_args, retvals, arg_cores, retval_cores, context->resources(), + std::move(token_output), + options.is_entry_computation ? options_.shape_representation_fn + : ShapeRepresentationFn{}, + options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, - &result->resource_updates)); + &result->resource_updates, &result->xla_output_shape)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - - // Compute the XLA output shape, if there is a computation with non-constant - // outputs. 
- TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, - client()->GetComputationShape(*result->computation)); - - result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " - << xla::ShapeUtil::HumanString(result->xla_output_shape); - - // Tensorflow expects a major-to-minor order of results. - xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - + << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 2cc603a58016a509fafdf6f95423dd6c0864cce3..63426124686e1b92a3534b7e365b8282008b8455 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,10 +18,13 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" @@ -118,7 +121,7 @@ class XlaCompiler { // The type of the argument. If the argument is a resource, this // is the type of the variable's value, not DT_RESOURCE. - DataType type; + DataType type = DT_INVALID; // The shape of the argument. For: // * a parameter: the shape of the parameter. @@ -155,6 +158,9 @@ class XlaCompiler { std::set tensor_array_gradients; bool operator==(const Argument& other) const; + + // Returns a human-readable summary of the argument. 
+ string HumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -259,8 +265,7 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function(const TensorShape&, - DataType)> + typedef std::function(const TensorShape&, DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. @@ -316,22 +321,23 @@ class XlaCompiler { Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, - std::vector args, CompilationResult* result); + absl::Span args, + CompilationResult* result); // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. Status CompileGraph(const CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result); - // Compiles a single Op, given by an OpKernelContext, into an + // Compiles a single Op, given by `node_def`, into an // xla::XlaComputation. Similar to CompileFunction but takes a single Op as // input. - Status CompileSingleOp(const CompileOptions& options, string const& name, - OpKernelContext* ctx, - const std::vector& args, + Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result); // Returns the shape of the XLA parameter for an argument 'arg'. 
@@ -411,7 +417,8 @@ class XlaCompiler { Status BuildArguments(const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, - XlaContext* context, std::vector* arg_cores, + XlaContext* context, + const std::map& arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 4ef154f856b9284a6c97f2c3072b198ccfb5e517..eba5d77efabd752f8476c27e95610343c54ea460 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -909,6 +911,82 @@ TEST_F(XlaCompilerTest, Variables) { RunAndCheckVariablesComputation(client_, result); } +TEST_F(XlaCompilerTest, ResultLayoutSingle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET"), a, 0); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. 
+ options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + auto compile_options = XlaCompiler::CompileOptions(); + compile_options.always_return_tuple = false; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph), + args, &result)); + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); +} + +TEST_F(XlaCompilerTest, ResultLayoutMultiple) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0); + auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. 
+ XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id", + std::move(graph), args, &result)); + xla::Shape result_shape = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({result_shape, result_shape}))); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -1018,9 +1096,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); @@ -1086,9 +1166,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Compiles the graph. 
XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 20e1ee2ddb390edd3a7d881022c68072a69193dc..43095fbb47351617a0de12a088c947106ccaa641 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( + bool allow_cpu_custom_calls, + const std::function( const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants), - is_entry_computation_(is_entry_computation), shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } -// This is called by the Retval Op to associate a computed value -// with a specific return value of the subgraph. -void XlaContext::AddRetval(int retval_index, DataType type, - const TensorShape& shape, const xla::XlaOp& handle) { - VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; - // Add the return value to the list being built up. 
- if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); +void XlaContext::SetRetval(int index, const XlaExpression& expression) { + if (retvals_.size() <= index) { + retvals_.resize(index + 1); } - XlaExpression e; - e.set_handle(handle); - retvals_[retval_index] = Retval{type, shape, e}; + retvals_[index] = expression; } -Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal) { - VLOG(1) << "Adding retval index " << retval_index - << " with non-data-dependent tensor to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - XlaExpression e; - e.set_constant_value(value); - retvals_[retval_index] = Retval{dtype, value.shape(), e}; - return Status::OK(); -} - -Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { - VLOG(1) << "Adding retval index " << retval_index << " with resource " - << resource->name() << ":" << resource->shape().DebugString() - << " to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - XlaExpression e; - e.set_resource(resource); - retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; - return Status::OK(); -} - -xla::XlaBuilder* XlaContext::builder() { return builder_; } - Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, @@ -133,7 +93,7 @@ Status XlaContext::CreateResource( return Status::OK(); } -xla::StatusOr XlaContext::RepresentationShape( +xla::StatusOr XlaContext::RepresentationShape( const TensorShape& shape, DataType type) const { return (*shape_representation_fn_)(shape, type); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 
4da891634e97dd67af0ef09ef33dbc7a4d19743b..dbfd344c9bad8a5d05abb6a3b902ed3baebbe02a 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,8 +20,8 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -46,9 +46,8 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( + bool allow_cpu_custom_calls, + const std::function( const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. @@ -57,37 +56,19 @@ class XlaContext : public ResourceBase { XlaCompiler* compiler() const { return compiler_; } // Returns the XlaBuilder that Ops use for compiling new expressions. - xla::XlaBuilder* builder(); + xla::XlaBuilder* builder() { return builder_; } bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool resolve_compile_time_constants() const { - return resolve_compile_time_constants_; - } - bool is_entry_computation() const { return is_entry_computation_; } - const std::vector& args() const { return args_; } void set_args(std::vector args); - struct Retval { - DataType type; - TensorShape shape; - // An XlaExpression representing the Retval's value. - XlaExpression expression; - }; - const std::vector& retvals() { return retvals_; } - - // This is called by the Retval Op to associate a computed value - // with a specific return value of the subgraph. 
- void AddRetval(int retval_index, DataType type, const TensorShape& shape, - const xla::XlaOp& handle); + const std::vector& retvals() { return retvals_; } - // As for Retval, but for return values that are compile-time constants. - Status AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal); - - // As for Retval, but for return values that are resource handles. - Status AddResourceRetval(int retval_index, XlaResource* resource); + // Sets a return value. + // Since we do not always know in advance how many return values there are, + // grows the return values vector to size index+1 if it is smaller. + void SetRetval(int index, const XlaExpression& expression); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` @@ -105,8 +86,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - xla::StatusOr RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -140,31 +121,19 @@ class XlaContext : public ResourceBase { // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; - // If true, constant return values are returned as Tensors instead of - // run-time computation outputs. - const bool resolve_compile_time_constants_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. 
std::vector> resources_; - // Is this a top-level computation, or an inner computation (e.g., a while - // body)? - const bool is_entry_computation_; - - // A function that describes how the shapes of - // a) argument and return value, for entry computations - // b) variables, for all computations, - // should be represented in XLA. Parameters/return values will be shaped - // according to this function, and reshaped back to/from their declared shapes - // for computations. Must be non-null. - const std::function(const TensorShape&, DataType)>* + // Describes the on-host shapes of parameters and return values. Also see: + // XlaDevice::Options::shape_representation_fn. + const std::function(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca0309166b7c73d1a5a818091e2a30fa112a4de4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_expression.h" + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +XlaExpression::XlaExpression() = default; + +XlaExpression XlaExpression::Invalid() { + XlaExpression e; + e.kind_ = Kind::kInvalid; + return e; +} + +XlaExpression XlaExpression::Constant(Tensor value) { + XlaExpression e; + e.kind_ = Kind::kConstant; + e.dtype_ = value.dtype(); + e.constant_value_ = value; + return e; +} + +XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { + XlaExpression e; + e.kind_ = Kind::kXlaOp; + e.dtype_ = dtype; + e.handle_ = value; + return e; +} + +XlaExpression XlaExpression::Resource(XlaResource* resource) { + XlaExpression e; + e.kind_ = Kind::kResource; + e.dtype_ = DT_RESOURCE; + e.resource_ = resource; + return e; +} + +string XlaExpression::HumanString() const { + switch (kind_) { + case Kind::kInvalid: + return "invalid"; + case Kind::kConstant: + return "constant"; + case Kind::kXlaOp: + return "xla_op"; + case Kind::kResource: + return "resource"; + } +} + +xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + switch (kind_) { + case Kind::kConstant: { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR( + HostTensorToBorrowingLiteral(constant_value_, &literal)); + return xla::ConstantLiteral(builder, literal); + } + case Kind::kXlaOp: + if (builder != handle_.builder()) { + return errors::InvalidArgument( + "Mismatched builders in XlaExpression::AsXlaOp"); + } + return handle_; + default: + return errors::InvalidArgument("AsXlaOp called on XlaExpression: ", + HumanString()); + } + }); +} + +xla::StatusOr> XlaExpression::ResolveConstant( + xla::Client* client) 
const { + switch (kind()) { + case Kind::kConstant: + return {constant_value()}; + case Kind::kXlaOp: + break; + case Kind::kResource: + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveConstant called on XlaExpression: ", HumanString()); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + handle().builder()->IsConstant(handle())); + if (!is_constant) return {absl::nullopt}; + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildConstantSubGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. + std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor)); + return {tensor}; +} + +xla::StatusOr XlaExpression::GetShape() const { + switch (kind_) { + case Kind::kConstant: + return constant_value().shape(); + case Kind::kXlaOp: { + TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, + handle().builder()->GetShape(handle())); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + return shape; + } + case Kind::kResource: + return TensorShape({}); + case Kind::kInvalid: + return errors::InvalidArgument( + "GetShape() called on invalid XlaExpression"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h new file mode 100644 index 0000000000000000000000000000000000000000..bed6761d362a98d344003c1edea342e68c31ef07 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA +// compilation. +// An expression is one of: +// * a constant tensor. +// * an xla::XlaOp, representing a symbolic XLA value. +// * a resource, e.g., a variable, represented as an XlaResource pointer. +// +// Constant tensors are mostly an optimization to avoid passing large constants +// to XLA, but are also sometimes used to represent tensors that have no XLA +// representation, for example, DT_STRING tensors. A canonical use case might be +// an error message string. +class XlaExpression { + public: + enum class Kind { + kInvalid, + kConstant, + kXlaOp, + kResource, + }; + + XlaExpression(); + XlaExpression(const XlaExpression&) = default; + XlaExpression& operator=(const XlaExpression&) = default; + + // Builds an invalid expression. (Same as the default constructor, but makes + // the intent clearer.) 
+ static XlaExpression Invalid(); + + // Builds a constant XLA expression. + static XlaExpression Constant(Tensor value); + + // Builds an XlaOp expression. Since the mapping from TF data types to XLA + // types is not 1-1, the TF type must also be provided; in general it cannot + // be derived from the XLA type. + static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + + // Builds a resource expression. + static XlaExpression Resource(XlaResource* resource); + + Kind kind() const { return kind_; } + + DataType dtype() const { return dtype_; } + + // handle() returns the XlaOp that backs a kXlaOp expression. + const xla::XlaOp& handle() const { return handle_; } + + const Tensor& constant_value() const { return constant_value_; } + + XlaResource* resource() const { return resource_; } + + // Returns a human-readable summary of the expression. + string HumanString() const; + + // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns + // an erroneous XlaOp if the expression is not a constant or an XlaOp. + xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; + + // If a kXlaOp or kConstant expression can be resolved to a compile-time + // constant, returns the value as a host-memory Tensor. Returns an empty + // optional if it cannot be resolved. Returns an error if passed a resource + // or invalid expression. + xla::StatusOr> ResolveConstant( + xla::Client* client) const; + + // Returns the shape of the tensor. + // The shape of a resource is the shape of a resource handle (i.e., a scalar), + // not the shape of the resource's value. + xla::StatusOr GetShape() const; + + private: + Kind kind_ = Kind::kInvalid; + + DataType dtype_ = DT_INVALID; + + // The XLA handle of the expression's computation, if kind_ == kXlaOp. + xla::XlaOp handle_; + + // The value of the constant, if kind_ == kConstant. + Tensor constant_value_; + + // The resource, if kind_ == kResource. Not owned.
+ XlaResource* resource_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..84202c931390f2d68f6d381aef0752bfff00a53d --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class XlaExpressionTest : public ::testing::Test { + protected: + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + builder_ = absl::make_unique("acomputation"); + constant_ = test::AsScalar(42); + op_ = xla::ConstantR0(builder_.get(), 7); + non_constant_op_ = xla::Parameter( + builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"); + resource_ = absl::make_unique( + XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"), + DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1, + /*tensor_array_gradients=*/std::set(), + /*tensor_array_multiple_writes_aggregate=*/false); + } + + xla::Client* client_; + std::unique_ptr builder_; + Tensor constant_; + xla::XlaOp op_; + xla::XlaOp non_constant_op_; + std::unique_ptr resource_; +}; + +TEST_F(XlaExpressionTest, Kind) { + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind()); + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind()); + EXPECT_TRUE(XlaExpression::Kind::kConstant == + XlaExpression::Constant(constant_).kind()); + EXPECT_TRUE(XlaExpression::Kind::kXlaOp == + XlaExpression::XlaOp(op_, DT_INT32).kind()); + 
EXPECT_TRUE(XlaExpression::Kind::kResource == + XlaExpression::Resource(resource_.get()).kind()); +} + +TEST_F(XlaExpressionTest, HumanString) { + EXPECT_EQ("invalid", XlaExpression().HumanString()); + EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString()); + EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString()); + EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString()); + EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString()); +} + +TEST_F(XlaExpressionTest, AsXlaOp) { + xla::XlaOp op_as_op = + XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get()); + EXPECT_TRUE(op_.IsIdenticalTo(op_as_op)); + + xla::XlaOp const_as_op = + XlaExpression::Constant(constant_).AsXlaOp(builder_.get()); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, + builder_->BuildConstantSubGraph(const_as_op)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal value, + client_->ComputeConstant(computation)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0(42), + value)); +} + +TEST_F(XlaExpressionTest, GetShape) { + EXPECT_FALSE(XlaExpression().GetShape().ok()); + EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok()); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape, + XlaExpression::Resource(resource_.get()).GetShape()); + EXPECT_EQ(TensorShape({}), resource_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape, + XlaExpression::XlaOp(op_, DT_INT32).GetShape()); + EXPECT_EQ(TensorShape({}), op_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape, + XlaExpression::Constant(constant_).GetShape()); + EXPECT_EQ(TensorShape({}), constant_shape); +} + +TEST_F(XlaExpressionTest, ResolveConstant) { + EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok()); + EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok()); + EXPECT_FALSE( + XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional op_constant, + 
XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_)); + ASSERT_TRUE(op_constant.has_value()); + test::ExpectTensorEqual(test::AsScalar(7), *op_constant); + + TF_ASSERT_OK_AND_ASSIGN(absl::optional op_nonconstant, + XlaExpression::XlaOp(non_constant_op_, DT_FLOAT) + .ResolveConstant(client_)); + EXPECT_FALSE(op_nonconstant.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional constant_constant, + XlaExpression::Constant(constant_).ResolveConstant(client_)); + ASSERT_TRUE(constant_constant.has_value()); + test::ExpectTensorEqual(constant_, *constant_constant); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9a34cd8c6ae2dc6d52a3cc69168df96f5322c6da..af378bc95c096082ff5cd963b9d6156f4351cd8d 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index dd3498ef7aa242d3ad946cae5f60bc2c8853a342..8dd8def0549f2b39d4c9863bb535f19703c3ef22 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().valid() || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle(); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); return expression; } -// Retrieves an uninitialized XlaExpression from a newly-allocated tensor. -static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { +// Assigns an XlaExpression to a tensor on an XLA compilation device. +static void AssignExpressionToTensor(Tensor* tensor, + const XlaExpression& value) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK(!expression->handle().valid()); - return const_cast(expression); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast(expression) = value; } -// Retrieves the XlaOp from an input Tensor to an Op. This computation was -// constructed by an Op that executed previously and created the output Tensor -// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. 
-static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { - return CastExpressionFromTensor(tensor)->handle(); +const XlaExpression& XlaOpKernelContext::InputExpression(int index) { + return *CastExpressionFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(int index) { - return GetComputationFromTensor(context_->input(index)); +const XlaExpression& XlaOpKernelContext::InputExpression( + absl::string_view name) { + return *CastExpressionFromTensor(GetInputTensorByName(name)); } -const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { - return GetComputationFromTensor(GetInputTensorByName(name)); +xla::XlaOp XlaOpKernelContext::Input(int index) { + return InputExpression(index).AsXlaOp(builder()); +} + +xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) { + return InputExpression(name).AsXlaOp(builder()); } TensorShape XlaOpKernelContext::InputShape(int index) { @@ -125,77 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name, Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal) { - const Tensor& tensor = context_->input(index); - TensorShape new_shape(new_dims); - if (tensor.NumElements() != new_shape.num_elements()) { - return errors::InvalidArgument( - context_->op_kernel().name(), " input ", index, " has shape ", - tensor.shape().DebugString(), - " but was asked to be reshaped to incompatible shape ", - new_shape.DebugString()); - } - const XlaExpression* expression = CastExpressionFromTensor(tensor); - - auto copy_tensor_to_literal = [](const Tensor& tensor, - xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); - - *literal = xla::Literal(literal_shape); - - // memcpy over the payload ... - // TODO(phawkins): handle string types. 
- size_t total_bytes = tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&tensor); - memcpy(dst_ptr, src_ptr, total_bytes); - } - return Status::OK(); - }; - - // If the tensor has a known constant value, there is no need to invoke XLA. - if (expression->has_constant_value()) { - Tensor temp(tensor.dtype()); - if (!temp.CopyFrom(expression->constant_value(), new_shape)) { - // This should never happen. The constant should have a shape compatible - // with the enclosing Tensor. - return errors::Internal("Incompatible shapes in ConstantInputReshaped."); - } - - return copy_tensor_to_literal(temp, constant_literal); - } - - // Make sure we treat zero-element tensors as constant. - if (new_shape.num_elements() == 0) { - Tensor temp(tensor.dtype(), new_shape); - - return copy_tensor_to_literal(temp, constant_literal); - } - - xla::XlaOp handle = expression->handle(); - if (new_shape != tensor.shape()) { - // Reshape the handle to the desired shape. - handle = xla::Reshape(handle, new_shape.dim_sizes()); - } - - // The XLA layout is specified minor to major, and TensorFlow's minor - // dimension is the last one. 
- std::vector layout_indices(new_shape.dims()); - std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); - xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); - - xla::StatusOr is_constant = builder()->IsConstant(handle); - if (!is_constant.ok()) { - Status status = is_constant.status(); + XlaExpression e = InputExpression(index); + xla::StatusOr> constant_or_status = + e.ResolveConstant(compiler()->client()); + if (!constant_or_status.ok()) { + Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", context_->op_kernel().type_string(), " operator as a compile-time constant."); return status; } - - if (!is_constant.ValueOrDie()) { + absl::optional constant = constant_or_status.ValueOrDie(); + if (!constant.has_value()) { return errors::InvalidArgument( "Input ", index, " to ", context_->op_kernel().type_string(), " operator must be a compile-time constant.\n" @@ -208,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped( "stateful operation such as a random number generator."); } - // Ask the XLA compiler to evaluate the data handle to a literal. 
- xla::StatusOr constant_graph = - builder()->BuildConstantSubGraph(handle); - if (!constant_graph.ok()) { - return errors::Internal( - "Error getting a compile-time constant graph for ", - context_->op_kernel().name(), " input ", index, - ".\nError: ", constant_graph.status().error_message()); - } - xla::StatusOr computed = compiler()->client()->ComputeConstant( - constant_graph.ValueOrDie(), &layout); - if (!computed.ok()) { - return errors::Internal("Error evaluating ", context_->op_kernel().name(), - " input ", index, - " as a compile-time constant.\nError: ", - computed.status().error_message()); + Tensor temp(constant->dtype()); + if (!temp.CopyFrom(*constant, TensorShape(new_dims))) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + constant->shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + TensorShape(new_dims).DebugString()); } - *constant_literal = std::move(computed).ValueOrDie(); + TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); return Status::OK(); } @@ -322,6 +259,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector( return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + absl::string_view name, std::vector* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -372,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(GetComputationFromTensor(input)); + handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); 
@@ -413,9 +359,12 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, XlaContext& xla_context = XlaContext::Get(ctx); TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, + xla::Shape representation_shape, xla_context.RepresentationShape(variable->shape(), variable->type())); - if (representation_shape == variable->shape()) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); + if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { *value = variable->value(); } else { *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); @@ -455,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, - Tensor** output) { - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - if (expected_output_dtype(index) == DT_VARIANT) { - // tensor_data() is not supported for variant Tensor (i.e., - // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the - // XlaExpression inside the Tensor's tensor_data() does not work for - // variant. Instead construct a uint8 tensor and store the expression in its - // value. - // TODO(jpienaar): This should be refactored to stop masquerading - // XlaExpressions as Tensors. 
- *output = new Tensor(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - context_->allocate_temp(DT_UINT8, tensor_shape, *output)); - context_->set_output(index, **output); - } else { - TensorShape tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); - TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); +void XlaOpKernelContext::SetOutputExpression(int index, + const XlaExpression& expression) { + Status status = [&] { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + Tensor* output = nullptr; + // Provides a special behavior for DT_VARIANT: a variant is treated as + // DT_UINT8 scalar as the type to allow mapping for variant to more generic + // types. + if (expression.dtype() == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in + // its value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, output)); + context_->set_output(index, *output); + } else { + TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); + TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); + } + AssignExpressionToTensor(output, expression); + return Status::OK(); + }(); + if (!status.ok()) { + SetStatus(status); } - return Status::OK(); } void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { - // Makes the host Tensor that will refer to the expression. 
- Tensor* output = nullptr; - auto shape_or = builder()->GetShape(handle); - if (!shape_or.ok()) { - SetStatus(shape_or.status()); - return; - } - - OP_REQUIRES_OK(context_, - allocate_output(index, shape_or.ValueOrDie(), &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); + SetOutputExpression( + index, + XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); } void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { - const TensorShape& shape = constant.shape(); - - xla::BorrowingLiteral literal; - OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); - - xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); - CHECK(handle.valid()); - - // Make the Tensor that will refer to the expression. - Tensor* output = nullptr; - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. 
- XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); - expression->set_constant_value(constant); -} - -void XlaOpKernelContext::SetInvalidOutput(int index) { - Tensor* output = nullptr; - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape({}), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::XlaOp handle; - expression->set_handle(handle); + SetOutputExpression(index, XlaExpression::Constant(constant)); } void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { - Tensor* output = nullptr; - // The shape of the output tensor is the shape of the resource itself - // (i.e., a scalar), not the shape of the resource's value. - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape(), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_resource(resource); + SetOutputExpression(index, XlaExpression::Resource(resource)); } Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { @@ -570,10 +482,13 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, xla_context.RepresentationShape(shape, type)); - if (shape != representation_shape) { - handle = xla::Reshape(handle, representation_shape.dim_sizes()); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { + handle = xla::Reshape(handle, + xla::AsInt64Slice(representation_shape.dimensions())); } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 
aa00a454968ad29495e34dc080e55b62bb0b5f7b..c06efa2c474c5ec3cb5d75d94ba15d4096faa085 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -88,9 +88,9 @@ class XlaOpKernelContext { // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::XlaOp& Input(int index); + xla::XlaOp Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(absl::string_view name); + xla::XlaOp Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -111,14 +111,6 @@ class XlaOpKernelContext { Status ConstantInput(int index, xla::Literal* constant_literal); Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); - // Evaluates input `index`, reshapes it to `new_shape` if new_shape != - // InputShape(index), and stores it in `*constant_literal`. If the input - // cannot be evaluated, e.g., because it depends on unbound parameters, - // returns a non-Ok status. If InputShape(index).num_elements() != - // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, absl::Span new_dims, - xla::Literal* constant_literal); - // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); Status ConstantInputAsIntScalar(absl::string_view name, int64* out); @@ -134,6 +126,8 @@ class XlaOpKernelContext { // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. Status ConstantInputReshapedToIntVector(int index, std::vector* out); + Status ConstantInputReshapedToIntVector(absl::string_view name, + std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. 
Status ConstantInputAsInt64Literal(int index, xla::Literal* out); @@ -148,6 +142,10 @@ class XlaOpKernelContext { Status ConstantInputList(absl::string_view name, std::vector* literals); + // Returns an XlaExpression describing the value of 'index'. + const XlaExpression& InputExpression(int index); + const XlaExpression& InputExpression(absl::string_view name); + // Outputs int num_outputs() const { return context_->num_outputs(); } @@ -165,9 +163,8 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output `index` to an invalid value. - // Any subsequent attempt to consume this output will cause an error. - void SetInvalidOutput(int index); + // Sets output `index` to the value described by `expression`. + void SetOutputExpression(int index, const XlaExpression& expression); // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } @@ -255,10 +252,13 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); - // Wraps OpKernelContext's allocate_output method while providing special - // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the - // type to allow mapping for variant to more generic types. - Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input + // cannot be evaluated, e.g., because it depends on unbound parameters, + // returns a non-Ok status. If InputShape(index).num_elements() != + // new_shape.num_elements(), returns an error status. 
+ Status ConstantInputReshaped(int index, absl::Span new_dims, + xla::Literal* constant_literal); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 9f00de708cc5aceb2c1e397663bc3bba8705bda4..14237df69081016817fbd1a5332f22996e7f264d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -129,21 +130,26 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. static void* registration_init = [&registry]() { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + mutex_lock lock(registry.mutex_); if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + cpu_global_jit + ? 
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally + : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; registration.compile_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = true; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; registration.compile_resource_ops = false; } return nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 45a40c0acc07805b422591fd7ea3fcb131db8471..0bdd4a1085445420a5147756daac4a54f4725f11 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -66,19 +66,26 @@ class XlaOpRegistry { public: typedef OpKernel* (*Factory)(OpKernelConstruction*); + enum class AutoclusteringPolicy { + // Enable autoclustering if the user requests it, e.g., via + // experimental_jit_scope. Does not autocluster if the JIT is enabled + // globally (e.g., via the OptimizerOptions in the TF session + // configuration.) + kIfExplicitlyRequested, + // Enable autoclustering if explicitly requested, or if the JIT is enabled + // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N. + kIfEnabledGlobally, + // Always try to autocluster ops placed on this device. + kAlways, + }; + // Describes how to compile operators assigned to a device. struct DeviceRegistration { // The name of the an XLA compilation device to use to compile code. string compilation_device_name; - // Do operators assigned to this device require compilation? - bool requires_compilation; - - // If !requires_compilation, should we try to JIT operators on this device - // when XLA JIT compilation is enabled globally via the SessionOptions? 
- // (It is still possible to explicitly mark operators to JIT compile, even - // if enable_jit_by_default is false.) - bool enable_jit_by_default; + // When should we autocluster operators assigned to this device? + AutoclusteringPolicy autoclustering_policy; // Enable compilation of operators that use DT_RESOURCE types? bool compile_resource_ops = false; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 63b09c8f02a60e91576544d13227d29f56d3e88c..a322eb9015e829fd468133f3de6c12aad7e4ff74 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -26,6 +26,19 @@ limitations under the License. namespace tensorflow { +/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) { + switch (kind) { + case XlaResource::kInvalid: + return "invalid"; + case XlaResource::kVariable: + return "variable"; + case XlaResource::kStack: + return "stack"; + case XlaResource::kTensorArray: + return "tensorarray"; + } +} + XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index aa9ce1b171f11ea0de4db0123098729c1c97f93a..857b9a928bb824656f637b2b1ca2fc02a1bef139 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -35,6 +36,7 @@ class XlaResource { kTensorArray, kStack, }; + static absl::string_view KindToString(Kind kind); XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d6b60c5f9916520ba7585824171aad1548610da6..d914e97b6bd4506251dc4be504d6ab427590e615 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -68,7 +68,7 @@ cc_library( visibility = [":friends"], deps = [ ":xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", ], ) @@ -735,6 +735,72 @@ tf_cc_test( ], ) +cc_library( + name = "parse_flags_from_env", + srcs = ["parse_flags_from_env.cc"], + hdrs = ["parse_flags_from_env.h"], + deps = + [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "parse_flags_from_env_test", + srcs = ["parse_flags_from_env_test.cc"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "debug_options_flags", + srcs = [ + "debug_options_flags.cc", + "debug_options_parsers.h", + ], + hdrs = ["debug_options_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + 
], +) + +tf_cc_test( + name = "debug_options_parsers_test", + size = "small", + srcs = [ + "debug_options_parsers.h", + "debug_options_parsers_test.cc", + ], + deps = + [ + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 782c966b4c57672d137569a318fb20ace14d493b..e4aca98f67d50287a83afc6f41a59458f3df2da2 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -104,7 +104,7 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? 
(to - from) / (count - 1) : 0); - auto set = [&array, n1, n2](int64 index, NativeT value) { + auto set = [&array, n2](int64 index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64 i = 0; i < count - 1; ++i) { diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 0cbe68d7efd9fe2ea46b312763437e1b8c986d25..42da0ebf4992884187bbe21701a44d8ba2fccd64 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -68,6 +68,7 @@ cc_library( deps = [ ":global_data", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", @@ -76,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -236,13 +236,13 @@ tf_cc_test( deps = [ ":xla_builder", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index f5f8d5c6b1fe265069992fe92acaa229647d4e8c..eef2844e0df6aaf509881535f41493673fbeeee5 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -210,11 +210,10 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const XlaComputation& computation, absl::Span arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteGraphRequest request; +StatusOr Client::Compile( + const XlaComputation& computation, absl::Span argument_shapes, + const ExecutionOptions* execution_options) { + CompileRequest request; *request.mutable_computation() = computation.proto(); if (execution_options == nullptr) { @@ -222,6 +221,34 @@ StatusOr> Client::Execute( } else { *request.mutable_execution_options() = *execution_options; } + if (request.execution_options().device_handles_size() > 1) { + return InvalidArgument( + "Compiling with multiple device handles is not supported. Use " + "'Execute' instead."); + } + + // The argument shapes affect how the computation is compiled. 
+ for (const auto& arg_shape : argument_shapes) { + *request.add_input_shape_with_layout() = arg_shape; + } + + CompileResponse response; + VLOG(1) << "making compile request: " << request.ShortDebugString(); + Status s = stub_->Compile(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + TF_RET_CHECK(response.has_handle()); + return response.handle(); +} + +StatusOr> Client::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { + ExecuteRequest request; + *request.mutable_handle() = handle; for (GlobalData* argument : arguments) { CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); @@ -229,7 +256,7 @@ StatusOr> Client::Execute( ExecuteResponse response; VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->ExecuteGraph(&request, &response); + Status s = stub_->Execute(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -238,15 +265,62 @@ StatusOr> Client::Execute( if (execution_profile != nullptr) { *execution_profile = response.profile(); + } + + return absl::make_unique(stub_, response.output()); +} + +StatusOr> Client::Execute( + const XlaComputation& computation, absl::Span arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + if (execution_options != nullptr && + execution_options->device_handles_size() > 1) { + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. 
+ for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); + } + } + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); + } + + // The argument shapes affect how the computation is compiled. + std::vector arg_shapes(arguments.size()); + for (int i = 0; i < arguments.size(); i++) { + TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); + } + + TF_ASSIGN_OR_RETURN(auto handle, + Compile(computation, arg_shapes, execution_options)); + + TF_ASSIGN_OR_RETURN(auto result, + Execute(handle, arguments, execution_profile)); + + if (execution_profile != nullptr) { if (VLOG_IS_ON(1)) { TF_ASSIGN_OR_RETURN( auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); + ExecutionStatsAsString(computation, *execution_profile)); VLOG(1) << execution_stats; } } - return absl::make_unique(stub_, response.output()); + return std::move(result); } StatusOr>> Client::ExecuteParallel( @@ -274,10 +348,11 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { + for (size_t i = 0; i < response.responses_size(); ++i) { outputs.push_back( absl::make_unique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { + if (i < computations.size() && + computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } } @@ -390,8 +465,7 @@ StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); + GetComputationStats(computation, GetDebugOptionsFromFlags())); int64 
total_flops = computation_stats.flop_count() + computation_stats.transcendental_count(); if (profile.compute_time_ns() > 0) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 6f4d33c469f1f885cfeef546e3981dc3417ef71f..d0ac4703c632e0e01d3c8911594b46fedf28930d 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -40,6 +40,31 @@ class Client { explicit Client(ServiceInterface* stub); virtual ~Client(); + // Compiles the computation with the given argument shapes and returns the + // handle to the compiled executable. The compiled executable is cached on the + // service, and the returned handle can be used for execution without + // re-compile. + // * The shape and layout of the arguments being executed with will affect how + // the computation is compiled. If argument_shapes is empty, the parameters' + // shape and layout will be used in the compilation. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * execution_options.device_handles should be empty. If you need + // non-empty device handles, call 'Execute' instead. + StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions* execution_options = nullptr); + + // Executes the compiled executable for the given handle with the given + // arguments and returns the global data that was produced from the execution. + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global // data that was produced from the execution. 
// * If execution_options is not nullptr, these options are passed to the diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index f833ddcd3235e08e2d0d3c0b9921e96ef871c89e..c5733bc66deb8d55a9186ad1893abaf17ed6909e 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -164,7 +164,6 @@ cc_library( deps = [ ":constants", ":math", - ":numeric", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", @@ -178,8 +177,9 @@ cc_library( srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ - ":numeric", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", ], @@ -188,10 +188,6 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index d3d7edb42a38595bbf9fdb36e0dd946ae5df51f9..08a887a6e4660cb2528f0ec7244b7ccc540808d2 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -265,6 +265,22 @@ XlaOp Digamma(XlaOp input) { return result; } +// Implements Banker's rounding: numbers that are equidistant between two +// integers are rounded towards even. 
+XlaOp RoundToEven(XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); + + auto round_val = xla::Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + round_val + one, round_val); +} + // Trigonometric functions. // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index a6cafd42077367bf23ffa1f45eab31c01dc31b16..3f06d04b9ae98b3aa75e68cd07810b2b4c24d280 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -51,6 +51,10 @@ XlaOp Lgamma(XlaOp input); // Computes an approximation of the digamma function. XlaOp Digamma(XlaOp input); +// Rounds the given number to even when the number is equidistant between two +// integers. +XlaOp RoundToEven(XlaOp x); + // Trigonometric functions // Computes the arc cosine of 'x'. 
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 14c259a7fa2a47642663b65d2785e5bbdc040cfd..ae2ea225d1aadd7b3a794eabeca866c498f34760 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -136,5 +136,17 @@ XLA_TEST_F(MathTest, Digamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, RoundToEven) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-1.4, -1.5, -2.5, -0.5, 0, 0.5, 1.5, 2.5, 3.5, 4.5}); + RoundToEven(x); + + std::vector expected = {-1.0, -2.0, -2.0, -0.0, 0, + 0.0, 2.0, 2.0, 4.0, 4.0}; + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h index efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1..f62fdab4b0e5e84347cfaa1424a8c2e5c58dd3ce 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -22,9 +22,6 @@ limitations under the License. namespace xla { -// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); - // Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere // else. XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index c6f68c8ee2f5198017c37abeb9551478f52a99f4..85b9e1827dcef5ed907d893277deb5a52f8f30e9 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index 0475fd9c94f6e390b5169cfe2cbba8eae28ddc18..e8553a08bb014e790822a14e128686b60b8d6b7c 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -23,13 +25,12 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; - int last_dim_size = input_shape.dimensions(last_dim); - XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + Shape iota_shape = + ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); + XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); - XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); - XlaOp sort_result = Sort(Neg(input), {broadcast_s32}); + XlaOp sort_result = Sort(Neg(input), {iota_s32}); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc 
b/tensorflow/compiler/xla/client/lib/sorting_test.cc index fef98c9923096e21a755c6d730de2c7c10852b2d..ebb30d3acc492a115f4980aaa4d2d08f73683864 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -56,5 +56,13 @@ XLA_TEST_F(SortingTest, TopKFullSort) { ComputeAndCompareR1(&builder, inputs, {}); } +XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) { + XlaBuilder builder(TestName()); + XlaOp a; + auto a_data = CreateR1Parameter({1, 1, 2, 2, 1}, 0, "a", &builder, &a); + xla::GetTupleElement(xla::TopK(a, 5), 1); + ComputeAndCompareR1(&builder, {2, 3, 0, 1, 4}, {a_data.get()}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index f96b6c9c261a9686fb647e3da0dcc933cd1f70df..aaa5d6989eefb94edb8921d13f96e3705aa3e3a4 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -310,4 +310,28 @@ StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } +StatusOr LocalClient::TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal) { + const ::xla::Shape& shape = literal.shape(); + + TF_ASSIGN_OR_RETURN( + ::xla::ScopedShapedBuffer shaped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + shape, backend().memory_allocator(), device_oridinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_oridinal)); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.get(), literal, shaped_buffer)); + std::vector<::xla::ScopedShapedBuffer> replicated_buffer; + replicated_buffer.emplace_back(std::move(shaped_buffer)); + ::xla::TransferToServerResponse result; + TF_ASSIGN_OR_RETURN(*result.mutable_data(), + local_service_->RegisterReplicatedBuffers( + 
std::move(replicated_buffer), + absl::StrCat("TransferToServer literal of shape ", + ::xla::ShapeUtil::HumanString(shape)))); + + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9..ddb36680e8b185b053368baffa6f1d5cac50dc07 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -60,8 +60,8 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. // - // The given ExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. Status ValidateExecutionOptions( const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); @@ -69,8 +69,8 @@ class LocalExecutable { // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. // - // The given ServiceExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ServiceExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const absl::Span arguments); @@ -114,8 +114,8 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // The given ExecutableBuildOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // environment variable. 
StatusOr> Compile( const XlaComputation& computation, const absl::Span argument_layouts, @@ -129,6 +129,10 @@ class LocalClient : public Client { const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator = nullptr); + // Transfer the BorrowingLiteral to the device with the given ordinal. + StatusOr TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal); + // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index f9c23b44810a52ae4dd40cc838e6cb575cb44445..f508ffb9c958ecfae7aea2c232e04001bd826a19 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -239,6 +239,19 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, visited->insert(op_handle); } +Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, + int64 target_dim_num) { + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( + DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, + dynamic_size_param_index}, + DynamicParameterBinding::DynamicDimension{ + target_param_num, target_param_index, target_dim_num})); + return Status::OK(); +} + XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); @@ -297,6 +310,9 @@ StatusOr XlaBuilder::Build(int64 root_id) { } module->add_computations()->Swap(&entry); + *(module->mutable_dynamic_parameter_binding()) = + dynamic_parameter_binding_.ToProto(); + // Clear data held by this builder. 
this->instructions_.clear(); this->handle_to_index_.clear(); @@ -2305,6 +2321,19 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, }); } +XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension)); + instr.add_dimensions(dimension); + return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, + {operand}); + }); +} + StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(first_error_); @@ -3158,4 +3187,8 @@ XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { return builder->Iota(shape, iota_dimension); } +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) { + return operand.builder()->GetDimensionSize(operand, dimension); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 908a616b4ead8820b5df991c3bc0b2f6724087ef..78c90dbccc486370377408d54406f4a896f60816 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -263,35 +264,30 @@ class XlaBuilder { // evaluating the computation. 
StatusOr IsConstant(const XlaOp& operand) const; + // Sets up binding which indicates that the `target_dim_num` in the subshape + // `target_param_index` of parameter `target_param_num` is a dynamic dimension + // and its real dynamic size is represented by `dynamic_param_index` in + // parameter `dynamic_param_num`. + // + // TODO(b/119520625): Remove this API once we have more dynamic shape infra + // ready. + Status SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, int64 target_dim_num); + private: // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id); - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. + // Description for the methods below can be found in the corresponding public + // functions section in this file. + XlaOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - // Enqueues a constant with the value of the given literal onto the - // computation. XlaOp ConstantLiteral(const LiteralSlice& literal); - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. template XlaOp ConstantR0(NativeT value); template @@ -321,181 +317,78 @@ class XlaBuilder { template XlaOp ConstantR4FromArray4D(const Array4D& values); - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. 
template XlaOp ConstantR1(int64 length, NativeT value); - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. - // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, const absl::Span broadcast_dimensions); - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, absl::Span new_sizes); - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. 
Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice XlaOp Slice(const XlaOp& operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. 
- // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. XlaOp ConcatInDim(absl::Span operands, int64 dimension); - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. void Trace(const string& tag, const XlaOp& operand); - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - // Enqueues a tuple-creation instruction onto the computation. XlaOp Tuple(absl::Span elements); - // Enqueues a tuple-element-get instruction onto the computation. 
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); - // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). 
XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, @@ -503,8 +396,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, @@ -512,8 +403,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -521,8 +410,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -532,80 +419,53 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, const string& config = ""); - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. 
- // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const string& outfeed_config); - // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, absl::Span operands); - // Enqueues a custom call instruction onto the computation. XlaOp CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions = {}); - // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); - // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a remainder instruction onto the computation. 
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -624,32 +484,23 @@ class XlaBuilder { XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Reduces several arrays simultaneously among the provided dimensions, given - // "computation" as a reduction operator. XlaOp Reduce(absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding); - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, @@ -659,48 +510,22 @@ class XlaBuilder { absl::Span window_dilations, absl::Span> padding); - // Returns the sum of the operand value within each subgroup of replicas. 
All - // replicas supply one input to the sum and all replicas receive the resulting - // sum for each subgroup. XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups = {}); - // Enqueues an operation that do an AllReduce of the operand cross cores. Here - // AllReduce means doing a reduction on the input operand cross cores and then - // broadcasting the reduction result to those cores. The reduction function is - // defined by `computation`, which should be a commutative computation on - // scalars, e.g., add, min, or max. The way that AllReduce is applied is - // configured by: - // - // - `replica_groups`: each ReplicaGroup contains a list of replica id. If - // empty, all replicas belong to one group. Allreduce will be applied within - // subgroups. For example, we have 4 replicas, then - // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, - // replica 1 and 3 are in subgroup 1. - // - // - `channel_id`: for Allreduce nodes from different modules, if they have - // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will - // not be applied cross modules. - // - // TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); - // Enqueues an operation that do an Alltoall of the operand cross cores. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); - // Enqueues an operation that do an CollectivePermute of the operand cross - // cores. XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. 
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -708,8 +533,6 @@ class XlaBuilder { const XlaOp& init_value, const XlaComputation& scatter); - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -717,222 +540,126 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); - // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, absl::Span broadcast_dimensions = {}); - // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); - // Enqueues an expm1 instruction onto the computation. XlaOp Expm1(const XlaOp& operand); - // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); - // Enqueues a ceil instruction onto the computation. XlaOp Ceil(const XlaOp& operand); - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. XlaOp Round(const XlaOp& operand); - // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); - // Enqueues an log1p instruction (log(x+1)) onto the computation. XlaOp Log1p(const XlaOp& operand); - // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); - // Enqueues a count leading zeros instruction onto the computation. XlaOp Clz(const XlaOp& operand); - // Enqueues a cosine instruction onto the computation. XlaOp Cos(const XlaOp& operand); - // Enqueues a sine instruction onto the computation. XlaOp Sin(const XlaOp& operand); - // Enqueues a tanh instruction onto the computation. 
XlaOp Tanh(const XlaOp& operand); - // Enqueues a real-part instruction onto the computation. XlaOp Real(const XlaOp& operand); - // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. XlaOp IsFinite(const XlaOp& operand); - // Enqueues an iota operation onto the computation. XlaOp Iota(const Shape& shape, int64 iota_dimension); - // Enqueues a rank-1 iota operation onto the computation. XlaOp Iota(PrimitiveType type, int64 size); - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); - // Enqueues a transpose instruction onto the computation. XlaOp Transpose(const XlaOp& operand, absl::Span permutation); - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - // Enqueues a sort (as increasing order) instruction onto the computation. 
- // If only keys are provided: - // * If the keys are an rank-1 tensor (an array), the result is a sorted array - // of keys, in ascending order. - // * If the keys have higher rank, the keys are sorted along the provided - // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension - // value of 0 will indepenently sort every column, and a dimension value of 1 - // will independently sort each row. If no dimension number is provided, then - // the last dimension is chosen by default. - // - // If both keys and values are provided: - // * The keys and all values must be tensors with the same dimensions. The - // element types of the tensors may be different. - // * The result is a tuple that consists of a sorted tensor of keys (along the - // provided dimension, as above) as the first element, and tensors with their - // corresponding values as the other elements. XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); - // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - // Enqueues a map instruction onto the computation. XlaOp Map(absl::Span operands, const XlaComputation& computation, absl::Span dimensions, absl::Span static_operands = {}); - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); - // Enqueues a while node onto the computation. XlaOp While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init); - // Enqueues a conditional node onto the computation. 
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation); - // Enqueues a ReducePrecision node onto the computation. XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - // Enqueues a Gather node onto the computation. XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); - // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - // Enqueues a Send node onto the computation for device-to-device - // communication, to send the given operand to a Recv instruction that shares - // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle); - // Enqueues a Send node which sends data to the host. XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const ChannelHandle& handle); - // Enqueues a Recv node which receives data from the host. XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp CreateToken(); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp AfterAll(absl::Span tokens); - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. 
XlaOp Recv(const Shape& shape, const ChannelHandle& handle); XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index); - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. 
- // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, int64 feature_index); + XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, absl::Span operands = {}); @@ -1017,6 +744,9 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // Dynamic parameter configuration of this computation. + DynamicParameterBinding dynamic_parameter_binding_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. absl::flat_hash_map handle_to_index_; @@ -1355,6 +1085,8 @@ class XlaBuilder { const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + + friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1389,6 +1121,7 @@ class XlaScopedShardingAssignment { // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. +// // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. @@ -2129,7 +1862,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& grad_output, float epsilon, int64 feature_index); +// Returns the size of the given dimension of the operand. The operand must be +// array shaped. +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + // Implementation details below this point. 
+// template XlaOp XlaBuilder::ConstantR0(NativeT value) { diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index dfe5fd5eb23ca51d2a449106a21293405a3dab6f..8aa85c3cd63c9b0aeb55d2cebbb989b6432ac959 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -43,7 +43,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -54,7 +54,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -349,6 +349,15 @@ TEST_F(XlaBuilderTest, CollectivePermute) { EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); } +TEST_F(XlaBuilderTest, GetDimensionSize) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + GetDimensionSize(x, 1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize); +} + TEST_F(XlaBuilderTest, ReportError) { 
XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc similarity index 96% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc rename to tensorflow/compiler/xla/debug_options_flags.cc index 3ed3afcfcede20fbf5c7d4f004378817febeb4c7..a40330a9b1fe201b6ec83d1bfe1a21e294e18f55 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -13,17 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace legacy_flags { - namespace { DebugOptions* flag_values; @@ -101,8 +99,8 @@ void AllocateFlags() { [](string comma_separated_values) { auto* extra_options_map = flag_values->mutable_xla_backend_extra_options(); - impl::parse_xla_backend_extra_options(extra_options_map, - comma_separated_values); + parse_xla_backend_extra_options(extra_options_map, + comma_separated_values); return true; }; @@ -111,8 +109,8 @@ void AllocateFlags() { [](string reduce_precision_option_value) { HloReducePrecisionOptions* option_proto = flag_values->add_hlo_reduce_precision_options(); - return impl::parse_xla_reduce_precision_option( - option_proto, reduce_precision_option_value); + return 
parse_xla_reduce_precision_option(option_proto, + reduce_precision_option_value); }; flag_objects = new std::vector({ @@ -337,7 +335,7 @@ void AllocateFlags() { "behavior to help run tests on the host that run models in parallel " "across multiple devices."), }); - ParseFlagsFromEnv(*flag_objects); + ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } } // namespace @@ -353,5 +351,4 @@ xla::DebugOptions GetDebugOptionsFromFlags() { return *flag_values; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h similarity index 81% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.h rename to tensorflow/compiler/xla/debug_options_flags.h index b53157f59c61cf4e0850e006ad3656f4be63a936..60e59abc2a2e0f1cce3de1afc928f9fe36f75b33 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ #include @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Appends flag definitions for debug options to flag_list. void AppendDebugOptionsFlags(std::vector* flag_list); @@ -32,7 +31,6 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // first. 
xla::DebugOptions GetDebugOptionsFromFlags(); -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/debug_options_parsers.h similarity index 94% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h rename to tensorflow/compiler/xla/debug_options_parsers.h index ee7eb019c07cf898e48886955b18710146644cac..80aadfd5ece0e768afaf1842d2b6c5b11c288b55 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/debug_options_parsers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ #include #include "absl/strings/numbers.h" @@ -23,8 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla.pb.h" namespace xla { -namespace legacy_flags { -namespace impl { template void parse_xla_backend_extra_options(T* extra_options_map, @@ -140,8 +138,6 @@ inline bool parse_xla_reduce_precision_option( return true; } -} // namespace impl -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/debug_options_parsers_test.cc similarity index 88% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc rename to tensorflow/compiler/xla/debug_options_parsers_test.cc index 6f197aec53c7596e84437a03affa9118f22f5a1d..8003c3496d5df9be2ff8a99bc171972c8e090c43 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/debug_options_parsers_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" #include #include @@ -23,13 +23,12 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace xla { -namespace legacy_flags { // Test that the xla_backend_extra_options flag is parsed correctly. 
TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { std::unordered_map test_map; string test_string = "aa=bb,cc,dd=,ee=ff=gg"; - impl::parse_xla_backend_extra_options(&test_map, test_string); + parse_xla_backend_extra_options(&test_map, test_string); EXPECT_EQ(test_map.size(), 4); EXPECT_EQ(test_map.at("aa"), "bb"); EXPECT_EQ(test_map.at("cc"), ""); @@ -41,7 +40,7 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -56,7 +55,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot;"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -71,7 +70,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -84,7 +83,7 @@ TEST(DebugOptionsFlags, 
ParseXlaReducePrecisionOptionNoOpcodes) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -96,7 +95,6 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz"); } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { diff --git a/tensorflow/compiler/xla/execution_options_util.cc b/tensorflow/compiler/xla/execution_options_util.cc index e83ff7cddd675197c7f6d7018257edb4c25b6228..cf569863bbe1c92bdcafb133d49dcf5ae8890ffe 100644 --- a/tensorflow/compiler/xla/execution_options_util.cc +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" namespace xla { ExecutionOptions CreateDefaultExecutionOptions() { ExecutionOptions execution_options; - *(execution_options.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); + *(execution_options.mutable_debug_options()) = GetDebugOptionsFromFlags(); return execution_options; } diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index fb135f5ceda67ce6c001de15b8f3f084ca164826..1fea816a803bfb75b9721393cef8c4dfc249268d 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.core.framework import attr_value_pb2 @@ -64,22 +61,18 @@ class Sharding(object): tile_assignment_devices=[core])) @classmethod - def tile(cls, tile_shape, tile_assignment): + def tile(cls, tile_assignment): """Returns a Tiled sharding attribute. This causes an op to be partially computed on multiple cores in the XLA device. Args: - tile_shape: A xla_shape.Shape describing the tile shape that each core - will compute. - The tile shape does not need to be divisible by the tile assignment. tile_assignment: An np.ndarray describing the topology of the tiling and which device will compute which part of the topology. 
Raises: - TypeError: tile_assignment was not of np.array type or tile_shape was - not of xla_shape.Shape type. + TypeError: tile_assignment was not of np.array type. TODO(jmolloy): This concept is nefarious and is not something we really want to expose to users (especially as the @@ -87,14 +80,11 @@ class Sharding(object): """ if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') - if not isinstance(tile_shape, xla_shape.Shape): - raise TypeError('Tile shape must be of type xla_shape.Shape') dims = list(tile_assignment.shape) flattened_devices = tile_assignment.reshape(-1, order='C') return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape.message, tile_assignment_dimensions=dims, tile_assignment_devices=list(flattened_devices))) @@ -118,14 +108,8 @@ class Sharding(object): shape = tensor.shape.as_list() if shape[split_dimension] < num_devices: raise ValueError('Split dimension was smaller than the required number ' - 'of splits: shape=%r, dimension=%r, num_devices=%r', - shape, split_dimension, num_devices) - - tile_shape = shape - tile_shape[split_dimension] = int( - math.ceil(tile_shape[split_dimension] / num_devices)) - tile_shape_proto = xla_data_pb2.Shape( - element_type=xla_data_pb2.F32, dimensions=tile_shape) + 'of splits: shape=%r, dimension=%r, num_devices=%r' % + (shape, split_dimension, num_devices)) tile_assignment_dims = [1] * len(shape) tile_assignment_dims[split_dimension] = num_devices @@ -133,7 +117,6 @@ class Sharding(object): return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape_proto, tile_assignment_dimensions=tile_assignment_dims, tile_assignment_devices=range(num_devices))) @@ -149,7 +132,6 @@ class Sharding(object): type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) else: proto = self._proto - attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) # 
TODO(jmolloy): This need to be seriously revisited before declaring this # API available for public use. @@ -194,8 +176,8 @@ def assign_device(tensor, device): return tensor -def tile(tensor, tile_shape, tile_assignment): - Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) +def tile(tensor, tile_assignment): + Sharding.tile(tile_assignment).apply_to_tensor(tensor) return tensor diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index bcfbcc3a22f50c748c388d17fbcd7defd27846d0..12b7094705e75305dc43a013576f4549dd5f4185 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -3,15 +3,15 @@ upper_tabs: - include: /_upper_tabs_left.yaml - include: /api_docs/_upper_tabs_api.yaml # Dropdown menu -- name: Ecosystem - path: /ecosystem +- name: Resources + path: /resources is_default: true menu: - - include: /ecosystem/_menu_toc.yaml + - include: /resources/_menu_toc.yaml lower_tabs: # Subsite tabs other: - - name: Guide + - name: Guide & Tutorials contents: - title: XLA overview path: /xla/overview @@ -27,3 +27,7 @@ upper_tabs: path: /xla/shapes - title: Using AOT compilation path: /xla/tfcompile + - heading: Tutorials + - title: XLA compile API + path: /xla/tutorials/xla_compile + status: experimental diff --git a/tensorflow/compiler/xla/g3doc/_index.yaml b/tensorflow/compiler/xla/g3doc/_index.yaml index 7934cd11ba22d3f47e172726f54ce51d15eb2cad..858de427119bfcfa82d0b1158776bf269129fd92 100644 --- a/tensorflow/compiler/xla/g3doc/_index.yaml +++ b/tensorflow/compiler/xla/g3doc/_index.yaml @@ -17,7 +17,7 @@ landing_page: - classname: devsite-landing-row-cards items: - heading: XLA - TensorFlow, compiled - image_path: /ecosystem/images/tf-logo-card-16x9.png + image_path: /resources/images/tf-logo-card-16x9.png path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html buttons: - label: Read on Google Developers blog @@ -28,7 +28,7 @@ landing_page: - 
label: Watch the video path: https://www.youtube.com/watch?v=kAOanJczHA0 - heading: XLA on GitHub - image_path: /ecosystem/images/github-card-16x9.png + image_path: /resources/images/github-card-16x9.png path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla buttons: - label: View on GitHub diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index 5376a04669d7c17a2fed8cdab46e21277049bf72..85fa16ccc7f48a3dce840564e79097c9e136767f 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -58,7 +58,7 @@ sess = tf.Session(config=config) > compiled for the CPU. JIT compilation for CPU operations must be done via > the manual method documented below. -#### Manual +#### Manual with experimental_jit_scope() JIT compilation can also be turned on manually for one or more operators. This is done by tagging the operators to compile with the attribute @@ -79,6 +79,16 @@ The `_XlaCompile` attribute is currently supported on a best-effort basis. If an operator cannot be compiled, TensorFlow will silently fall back to the normal implementation. +#### Manual with xla.compile() + +Unlike experimental_jit_scope() which silently falls back to normal Tensorflow +on uncompilable operator, xla.compile() returns an explicit error. This is +useful if you want more predictable behaviors from XLA compilation. + +Please see +[xla.compile() tutorial Colab](./tutorials/xla_compile.ipynb) +for how to use it. + ### Placing operators on XLA devices Another way to run computations via XLA is to place an operator on a specific @@ -134,7 +144,7 @@ Execute the python script to train the model with XLA and turn on a debugging feature of XLA via an environmental variable that outputs the XLA graph. 
```shell -TF_XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py +XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py ``` Open the timeline file created (`timeline.ctf.json`). The rendered timeline diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index a3cdfe19b2e3470bb903cce6cbc79d8d13cc8349..73a9db75f6bf090bba5c3534f14d8ebfa421b5bb 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1339,6 +1339,22 @@ the semantics for `tf.gather_nd`. index `X` in the gather indices array picks an entire row and the result is the concatenation of all these rows. +## GetDimensionSize + +See also +[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the size of the given dimension of the operand. The operand must be +array shaped. + + `GetDimensionSize(operand, dimension)` + +| Arguments | Type | Semantics | +| ----------- | ------- | --------------------------------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | +: : : dimension : + ## GetTupleElement See also diff --git a/tensorflow/compiler/xla/g3doc/overview.md b/tensorflow/compiler/xla/g3doc/overview.md index 6a172c3ae159974fb4a34ec422a9a96079b0814a..d3428b7276131e8f406f60cfea9a9346c5478433 100644 --- a/tensorflow/compiler/xla/g3doc/overview.md +++ b/tensorflow/compiler/xla/g3doc/overview.md @@ -4,11 +4,8 @@ -> Note: XLA is experimental and considered alpha. Most use cases will not -> see improvements in performance (speed or decreased memory usage). We have -> released XLA early so the Open Source Community can contribute to its -> development, as well as create a path for integration with hardware -> accelerators. 
+> Note: XLA is still under development. Some use cases will not +> see improvements in speed or decreased memory usage. XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that optimizes TensorFlow computations. The results are improvements in diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2a83092805be5efdd7b9ab54449b2bcc6a2ec481 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -0,0 +1,373 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "The XLA compile API", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "colab_type": "text", + "id": "f4TSNCvpENrW" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "cellView": "form", + "colab_type": "code", + "id": "vamNSA0vEP-m", + "colab": {} + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "e1oSi4lHFt3z" + }, + "cell_type": "markdown", + "source": [ + "# The XLA compile API" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "b7noD9NjFRL-" + }, + "cell_type": "markdown", + "source": [ + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "v9YbsuLZaBXy" + }, + "cell_type": "markdown", + "source": [ + "\n", + "\n", + "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "45kUPj5ZFrRa", + "colab": {} + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "\n", + "from tensorflow.contrib.compiler import xla" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "GZVNiRmTDV-5" + }, + "cell_type": "markdown", + "source": [ + "Define some necessary constants and prepare the MNIST dataset." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "f37TSEGvGX4_", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "code", + "id": "TiVXchblG5hK", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n", + "\n", + "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n", + "images, labels = iterator.get_next()\n", + "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", + "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + 
"colab_type": "text", + "id": "x_ZehpZP-SfS" + }, + "cell_type": "markdown", + "source": [ + "# Define the model constructing function\n", + "\n", + "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", + "\n", + "When called, it returns two values. `y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "ZbhJl_WvGa3g", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def build_mnist_model(x, y_):\n", + " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", + "\n", + " cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n", + " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", + "\n", + " return y, train_step" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "7Jh3lyQHDfM9" + }, + "cell_type": "markdown", + "source": [ + "# Enable XLA\n", + "\n", + "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "kYpCXCdRHNuN", + "colab": {} + }, + "cell_type": "code", + "source": [ + "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "4giQh62IrZGF" + }, + "cell_type": "markdown", + "source": [ + "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", + "\n", + "xla.compile does not return any\n", + "`tf.Operation` nodes that can be executed independently from the generated XLA ops. 
Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n", + "\n", + "In pseudo-code, xla.compile's implementation looks as follows:\n", + "\n", + "---\n", + "```\n", + "# Ask Tensorflow to execute code in XLA-friendly manner\n", + "\n", + "y, train_step = build_mnist_model(images, labels)\n", + "with tf.control_dependencies([train_step]):\n", + " y = tf.identity(y)\n", + "\n", + "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n", + "```\n", + "---\n", + "\n", + "xla.compile() always returns a list of `tf.Tensor`'s (even if there is only one-element)." + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "TPGas4jjFLZl" + }, + "cell_type": "markdown", + "source": [ + "If you were to print the constructed graph now, you will see that it is not much different from a normal Tensorflow graph and you won't be able to find XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, Tensorflow triggers a series of graph rewrite passes that actually generate XLA ops, which compiles and executes computation when all inputs are ready." 
+ ] + }, + { + "metadata": { + "colab_type": "text", + "id": "EZD1m_n1DxAF" + }, + "cell_type": "markdown", + "source": [ + "# Train and test the model" + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "qe28bAHNHUG2", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Creates session and initialize all variables.\n", + "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", + "sess = tf.Session()\n", + "sess.run(tf.global_variables_initializer())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "qgsKmz3n2UiW" + }, + "cell_type": "markdown", + "source": [ + "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "_GxF6jTRHVuA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132" + }, + "cell_type": "code", + "source": [ + "# Feeds training dataset\n", + "sess.run(iterator.make_initializer(train_ds))\n", + "\n", + "# Runs TRAIN_STEPS steps\n", + "for i in range(TRAIN_STEPS):\n", + " sess.run(y)\n", + "\n", + "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" + ], + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Model trained for 1000 steps.\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "dHlQlRSRHXD1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093" + }, + "cell_type": "code", + "source": [ + "# Tests trained model\n", + "\n", + "# Feeds testing dataset\n", + "sess.run(iterator.make_initializer(test_ds))\n", + "\n", + "# Calculates accuracy\n", + "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + 
"print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Prediction accuracy after training: 0.91\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "ynJQIuzjHYOb", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Cleans up session\n", + "sess.close()" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 93522d2ca87a7eba8d3c7533785c54e63ce507b0..fa94d0afb4c9280b8f8fa9642c1b0ab7285ee6f3 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -24,8 +24,7 @@ limitations under the License. namespace xla { namespace { -void SetMinorToMajorLayout(Shape* shape, - std::initializer_list dimensions) { +void SetMinorToMajorLayout(Shape* shape, std::vector dimensions) { shape->mutable_layout()->clear_minor_to_major(); for (auto dimension : dimensions) { shape->mutable_layout()->add_minor_to_major(dimension); @@ -122,7 +121,7 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { std::vector linear_indexes = {0, 1439999999, 1145567336, 43883404, 617295214, 1117613654}; - std::vector> minor_to_major_orders; + std::vector> minor_to_major_orders; minor_to_major_orders.push_back({6, 5, 4, 3, 2, 1, 0}); minor_to_major_orders.push_back({0, 1, 2, 3, 4, 5, 6}); minor_to_major_orders.push_back({4, 5, 1, 2, 6, 0, 3}); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD deleted file mode 100644 index 3e79129aafd234e5eab05d205f2017b54057795e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ /dev/null @@ -1,82 +0,0 @@ -# Legacy command-line flags for the XLA libraries. - -# Please do not add more flags to this package. 
- -# The XLA libraries were written in an environment that allowed command-line -# flags to be scattered freely throughout the libraries. This model, while -# initially convenient, leads to a proliferation in unused command-line flags -# in tests and binaries, and serious problems in servers, where one might wish -# parameters to be different in independent RPC calls to the same routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. - -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "parse_flags_from_env", - srcs = ["parse_flags_from_env.cc"], - hdrs = ["parse_flags_from_env.h"], - deps = - [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "parse_flags_from_env_test", - srcs = ["parse_flags_from_env_test.cc"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "debug_options_flags", - srcs = [ - "debug_options_flags.cc", - "debug_options_parsers.h", - ], - hdrs = ["debug_options_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "debug_options_parsers_test", - size = "small", - srcs = [ - "debug_options_parsers.h", - "debug_options_parsers_test.cc", - ], - deps = - [ - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", 
- "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h deleted file mode 100644 index b54482ad2ba2224c781861341a80ceb878ffd343..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ - -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from the environtment variable TF_XLA_FLAGS, or (if the first -// non-whitespace in the variable value is not '-'), a file named by that -// environment variable. The accepted syntax is that flags arguments are of -// the form --flag=value or (for boolean flags) --flag, and are whitespace -// separated. The may be one of: -// - -// in which case the effective value is the string itself -// - in which case the effective value is the -// string with the single-quotes removed -// - in which case the effective value if the -// string with the double-quotes removed, and escaped sequences of -// replaced by . 
-// -// Flags values inconsistent with the type of the flag will be rejected by the -// flag parser. -// -// Examples: -// TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" -// -// TF_XLA_FLAGS=/tmp/flagfile -// where /tmp/flagfile might contain -// --some_flag="This is a string containing a \" and a '." -// --another_flag=wombats - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet -// unrecognized flags passed in from the environment, and return its -// return value. -bool ParseFlagsFromEnv(const std::vector& flag_list); - -// Used only for testing. Not to be used by clients. -void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 510aa39b4503111ec558c050f0c332c93de10517..36ad7c64866e77187d40f22b364d80230651696b 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -1012,167 +1013,143 @@ void LiteralBase::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { - const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - CHECK(LayoutUtil::HasLayout(literal.shape())); - CHECK(LayoutUtil::HasLayout(subshape)); +string ShapeToString(bool print_layout, const Shape& shape) { + return print_layout ? ShapeUtil::HumanStringWithLayout(shape) + : ShapeUtil::HumanString(shape); +} - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces); - // TODO(b/32894291): refactor this code to reduce code duplication. 
- if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); +void TupleToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); + } + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); + pieces->push_back("\n)"); +} + +void SparseArrayToStringHelper(const LiteralBase& literal, + const Shape& subshape, bool print_layout, + std::vector* pieces) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); } - pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (ShapeUtil::IsToken(subshape)) { - pieces->push_back("token"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i 
< num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); } - pieces->push_back("}"); - return; + pieces->push_back(literal.GetSparseElementAsString(i)); } + pieces->push_back("}"); +} - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = [&](absl::Span indices) -> string { - PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - literal.GetAsString(indices, shape_index); - }; +void DenseArrayToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + int64 rank = ShapeUtil::Rank(subshape); + + std::function dimensions, std::vector*)> + to_string_recursive = [&](absl::Span dimensions, + std::vector* accum_indices) { + // dimensions.size() decreases by 1 at each recursive call, + // and accum_indices->size() increases by 1. + // Their sum is equal to the rank of the tensor. + CHECK_EQ(rank, dimensions.size() + accum_indices->size()); + + auto brace_to_string = [&](string brace) -> string { + // Handle 1D tensor + if (rank == 1) { + return brace; + } + // Handle the innermost tensor of a 2D+ tensor. 
+ if (dimensions.size() == 1 && brace == "{") { + return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " "); + } + if (dimensions.size() == 1 && brace == "}") { + return StrCat(dimensions[0] <= 1 ? "" : " ", brace); + } + // Handle the non-innermost tensors of a 2D+ tensor. + if (brace == "{") { + if (rank > 3 && !accum_indices->empty() && + accum_indices->size() < rank) { + int index = accum_indices->size() - 1; + int value = accum_indices->back(); + return StrCat(brace, " /*i", index, "=", value, "*/\n"); + } + return StrCat(brace, "\n"); + } + return StrCat("\n", brace); + }; - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? 
",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); + if (dimensions.empty()) { + // Display predicates as 0s and 1s so that the string is more dense. + string elem; + if (subshape.element_type() == PRED && rank > 0) { + elem = literal.Get(*accum_indices, shape_index) ? "1" : "0"; + } else { + elem = literal.GetAsString(*accum_indices, shape_index); } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + pieces->push_back(elem); + } else { + pieces->push_back(brace_to_string("{")); + for (int i = 0; i < dimensions[0]; ++i) { + std::vector cloned_indices(*accum_indices); + cloned_indices.push_back(i); + to_string_recursive(dimensions.subspan(1), &cloned_indices); + if (i < dimensions[0] - 1) { + pieces->push_back(","); + pieces->push_back(dimensions.size() > 1 ? "\n" : " "); } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); + pieces->push_back(brace_to_string("}")); } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); + }; + + if (rank > 1) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" "); + } + std::vector indices = {}; + std::vector dimensions(subshape.dimensions().begin(), + subshape.dimensions().end()); + to_string_recursive(dimensions, &indices); +} + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + if (ShapeUtil::IsTuple(subshape)) { + TupleToStringHelper(literal, shape_index, print_layout, pieces); + } else if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + } else if (LayoutUtil::IsSparseArray(subshape)) { + SparseArrayToStringHelper(literal, subshape, print_layout, pieces); } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](absl::Span indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); + CHECK(LayoutUtil::IsDenseArray(subshape)); + DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); } } @@ -1435,10 +1412,14 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case U8: return EqualElementsInternal(other, &multi_index); + case S16: + return EqualElementsInternal(other, &multi_index); case S32: return EqualElementsInternal(other, &multi_index); case S64: return EqualElementsInternal(other, &multi_index); + case U16: + return EqualElementsInternal(other, &multi_index); case U32: return EqualElementsInternal(other, &multi_index); case U64: @@ -1507,6 +1488,11 @@ bool LiteralBase::IsAll(int8 value) const { return AllElementsEqualValue(piece.data(), value); } return false; + case U16: + if (value >= 0) { 
+ return AllElementsEqualValue(piece.data(), value); + } + return false; case U32: if (value >= 0) { return AllElementsEqualValue(piece.data(), value); @@ -1519,6 +1505,8 @@ bool LiteralBase::IsAll(int8 value) const { return false; case S8: return AllElementsEqualValue(piece.data(), value); + case S16: + return AllElementsEqualValue(piece.data(), value); case S32: return AllElementsEqualValue(piece.data(), value); case S64: @@ -1740,12 +1728,16 @@ bool LiteralBase::IsZero(absl::Span indices) const { switch (shape().element_type()) { case U8: return Get(indices) == 0; + case U16: + return Get(indices) == 0; case U32: return Get(indices) == 0; case U64: return Get(indices) == 0; case S8: return Get(indices) == 0; + case S16: + return Get(indices) == 0; case S32: return Get(indices) == 0; case S64: @@ -1803,6 +1795,20 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case U16: + *proto->mutable_u16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_u16s()); + } + break; + case S16: + *proto->mutable_s16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_s16s()); + } + break; case F16: *proto->mutable_f16s() = string( reinterpret_cast(data().data()), size_bytes()); @@ -1917,6 +1923,22 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case S16: { + const string& s(proto.s16s()); + TF_RET_CHECK(data().size() * sizeof(int16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case U16: { + const string& s(proto.u16s()); + TF_RET_CHECK(data().size() * sizeof(uint16_t) == s.size()); + memcpy(untyped_data(), s.data(), 
s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; case F16: { const string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index e791048b4d9f5dcf877e05e3b5cf16eb37c07dbc..fa9a71af4ceb998a7a289443cbef70eb52cb1a11 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -301,7 +301,7 @@ class LiteralBase { // // Note: It's an antipattern to use this method then immediately call // MutableLiteralBase::Populate on the result (since that results in zero - // initialization, then reinitialization. Conside if a call to + // initialization, then reinitialization. Consider if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. static Literal CreateFromShape(const Shape& shape); diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 9d34d9d504156c4b9e645ccfa7cdbd346e51390b..b044f0ad73f13a0599e77f1f43888bc974e31f73 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -141,8 +141,10 @@ int64 RecursiveElementCount(const Shape& shape) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else { + } else if (ShapeUtil::IsArray(shape)) { return ShapeUtil::ElementsIn(shape); + } else { + return 0; } } diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 4ae5ddbfdb8444ac778f82d01b1066aad8c0aa78..bd93517728b052aed854df0f9d9c5447bc3b156f 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - 
EXPECT_EQ("{101}", pred_vec.ToString()); + EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -150,12 +150,58 @@ TEST_F(LiteralUtilTest, R3ToString) { const auto literal = LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { -{ { 1 }, - { 2 } }, -{ { 3 }, - { 4 } }, -{ { 5 }, - { 6 } } +{ + {1}, + {2} +}, +{ + {3}, + {4} +}, +{ + {5}, + {6} +} +})"; + EXPECT_EQ(expected, literal.ToString()); +} + +TEST_F(LiteralUtilTest, R6ToString) { + const auto literal = + LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2}); + const string expected = R"(s32[2,2,1,1,1,2] { +{ /*i0=0*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +}, +{ /*i0=1*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +} })"; EXPECT_EQ(expected, literal.ToString()); } @@ -190,12 +236,16 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[2,3,2] { -{ { 1, 2 }, +{ + { 1, 2 }, { 3, 4 }, - { 5, 6 } }, -{ { 7, 8 }, + { 5, 6 } +}, +{ + { 7, 8 }, { 9, 10 }, - { 11, 12 } } + { 11, 12 } +} })"; EXPECT_EQ(expected, result); } @@ -247,18 +297,18 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - }, - { /*i1=1*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +}, +{ /*i1=1*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +} +} })"; EXPECT_EQ(expected, result); } @@ -268,30 +318,30 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { ElementsAre(2, 2, 3, 3)); string result = 
literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2, 3}, - {4, 5, 6}, - {7, 8, 9} - }, - { /*i1=1*/ - {11, 12, 13}, - {14, 15, 16}, - {17, 18, 19} - } - }, - { /*i0=1*/ - { /*i1=0*/ - {101, 102, 103}, - {104, 105, 106}, - {107, 108, 109} - }, - { /*i1=1*/ - {201, 202, 203}, - {204, 205, 206}, - {207, 208, 209} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } +}, +{ /*i1=1*/ + { 11, 12, 13 }, + { 14, 15, 16 }, + { 17, 18, 19 } +} +}, +{ /*i0=1*/ +{ /*i1=0*/ + { 101, 102, 103 }, + { 104, 105, 106 }, + { 107, 108, 109 } +}, +{ /*i1=1*/ + { 201, 202, 203 }, + { 204, 205, 206 }, + { 207, 208, 209 } +} +} })"; EXPECT_EQ(expected, result); } @@ -1394,6 +1444,28 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { EXPECT_EQ(h1, r[3]); } +TEST_F(LiteralUtilTest, CopyFromProto_u16) { + uint16 u1(0xabcd); + uint16 u2(0x1234); + + const unsigned char uint16_vals[8] = {0xcd, 0xab, 0x34, 0x12, + 0x34, 0x12, 0xcd, 0xab}; + LiteralProto p; + p.mutable_shape()->set_element_type(U16); + p.mutable_shape()->clear_dimensions(); + p.mutable_shape()->add_dimensions(4); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + p.clear_u16s(); + p.set_u16s(uint16_vals, 8); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); + ASSERT_EQ(4, r.size()); + EXPECT_EQ(u1, r[0]); + EXPECT_EQ(u2, r[1]); + EXPECT_EQ(u2, r[2]); + EXPECT_EQ(u1, r[3]); +} + TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1515,9 +1587,9 @@ TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nested_tuple = LiteralUtil::MakeTuple( {&tuple_elements[0], &tuple_elements[1], &nil_literal}); - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); std::vector elements = nested_tuple.DecomposeTuple(); - 
EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1568,7 +1640,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { EXPECT_EQ(literal.Get({1}, /*shape_index=*/{2, 1}), 44.0); for (const Literal& element : elements) { - EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(element.shape())); } } diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc similarity index 65% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc rename to tensorflow/compiler/xla/parse_flags_from_env.cc index 2a4e49b05aa0d1eed2197095694cfc6aa8814983..5b568888d14f21c1330556d017eafba6c8dd2228 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -13,16 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from an environtment variable, or a file named by the environment -// variable. +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or a file named by the +// environment variable. #include #include #include +#include +#include #include -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -31,9 +36,7 @@ limitations under the License. 
#include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { -static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried static const char kWS[] = " \t\r\n"; // whitespace // The following struct represents an argv[]-style array, parsed @@ -43,12 +46,20 @@ static const char kWS[] = " \t\r\n"; // whitespace // constructor/destructor collisions with other "private" types // in the same named namespace. namespace { + +// Functor which deletes objects by calling `free`. Necessary to free strdup'ed +// strings created by AppendToEnvArgv. +struct FreeDeleter { + void operator()(char* ptr) { free(ptr); } +}; + struct EnvArgv { EnvArgv() : initialized(false), argc(0) {} bool initialized; // whether the other fields have been set. int argc; // elements used in argv[] std::vector argv; // flag arguments parsed from environment string. - std::vector argv_save; // saved values from argv[] to avoid leaks + // saved values from argv[] to avoid leaks + std::vector> argv_save; }; } // anonymous namespace @@ -64,7 +75,7 @@ static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1, string s = string(s0, s0len) + string(s1, s1len); char* str = strdup(s.c_str()); a->argv.push_back(str); - a->argv_save.push_back(str); + a->argv_save.emplace_back(str); a->argc++; } } @@ -128,14 +139,14 @@ static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { } } -// Call ParseArgvFromString(..., a) on a string derived from the setting of an -// environment variable kEnvVar, or a file it points to. -static void SetArgvFromEnv(EnvArgv* a) { +// Call ParseArgvFromString(..., a) on a string derived from the setting of the +// environment variable `envvar`, or a file it points to. 
+static void SetArgvFromEnv(absl::string_view envvar, EnvArgv* a) { if (!a->initialized) { static const char kDummyArgv[] = ""; AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0, a); // dummy argv[0] - const char* env = getenv(kEnvVar); + const char* env = getenv(string(envvar).c_str()); if (env == nullptr || env[0] == '\0') { // nothing } else if (env[strspn(env, kWS)] == '-') { // flags in env var value @@ -158,49 +169,66 @@ static void SetArgvFromEnv(EnvArgv* a) { } } -// The simulated argv[] parsed from the environment. -static EnvArgv* env_argv; +// The simulated argv[] parsed from the environment, one for each different +// environment variable we've seen. +static std::unordered_map& EnvArgvs() { + static auto* env_argvs = new std::unordered_map(); + return *env_argvs; +} -// Used to protect accesses to env_argv. +// Used to protect accesses to env_argvs. static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED); -// Call Flags::Parse(argc, argv, flag_list) against any as yet unrecognized -// flags passed in from the environment. -bool ParseFlagsFromEnv(const std::vector& flag_list) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - SetArgvFromEnv(env_argv); // a no-op if already initialized +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list) { + tensorflow::mutex_lock lock(env_argv_mu); + auto* env_argv = &EnvArgvs()[string(envvar)]; + SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized bool result = tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); - env_argv_mu.unlock(); + + // There's always at least one unparsed argc, namely the fake argv[0]. + if (result && env_argv->argc != 1) { + // Skip the first argv, which is the fake argv[0]. + auto unknown_flags = absl::MakeSpan(env_argv->argv); + unknown_flags.remove_prefix(1); + + // Some flags are set on XLA_FLAGS, others on TF_XLA_FLAGS. 
If we find an + // unrecognized flag, suggest the alternative. + string alternate_envvar; + if (envvar == "TF_XLA_FLAGS") { + alternate_envvar = "XLA_FLAGS"; + } else if (envvar == "XLA_FLAGS") { + alternate_envvar = "TF_XLA_FLAGS"; + } + string did_you_mean; + if (!alternate_envvar.empty()) { + did_you_mean = absl::StrFormat( + "\nPerhaps you meant to specify these on the %s envvar?", + alternate_envvar); + } + + LOG(FATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") + << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") + << did_you_mean; + return false; + } return result; } // Testing only. -// Reset the env_argv struct so that subsequent calls to ParseFlagsFromEnv() -// will parse the environment variable (or the file it points to) anew, and set -// *pargc, and *pargv to point to the internal locations of the argc and argv -// constructed from the environment. -void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - if (!env_argv->argv_save.empty()) { - for (int i = 0; env_argv->argv_save[i] != nullptr; i++) { - free(env_argv->argv_save[i]); - } - } - env_argv->initialized = false; - env_argv->argc = 0; - env_argv->argv.clear(); - env_argv->argv_save.clear(); - env_argv_mu.unlock(); - *pargc = &env_argv->argc; - *pargv = &env_argv->argv; +// +// Resets the env_argv struct so that subsequent calls to +// ParseFlagsFromEnvAndDieIfUnknown() will parse the environment variable (or +// the file it points to) anew, and set *pargc, and *pargv to point to the +// internal locations of the argc and argv constructed from the environment. 
+void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv) { + tensorflow::mutex_lock lock(env_argv_mu); + EnvArgvs().erase(string(envvar)); + auto& env_argv = EnvArgvs()[string(envvar)]; + *pargc = &env_argv.argc; + *pargv = &env_argv.argv; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/parse_flags_from_env.h b/tensorflow/compiler/xla/parse_flags_from_env.h new file mode 100644 index 0000000000000000000000000000000000000000..76940a4299ac50138222333ff250a264cc941288 --- /dev/null +++ b/tensorflow/compiler/xla/parse_flags_from_env.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ +#define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ + +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or (if the first +// non-whitespace in the variable value is not '-'), a file named by that +// environment variable. +// +// The accepted syntax is that flags arguments are of the form --flag=value or +// (for boolean flags) --flag, and are whitespace separated. 
The may be +// one of: +// +// - +// in which case the effective value is the string itself +// - in which case the effective value is the +// string with the single-quotes removed +// - in which case the effective value if the +// string with the double-quotes removed, and escaped sequences of +// replaced by . +// +// Flags values inconsistent with the type of the flag will be rejected by the +// flag parser. +// +// Examples: +// +// - TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" +// - TF_XLA_FLAGS=/tmp/flagfile +// +// where /tmp/flagfile might contain +// +// --some_flag="This is a string containing a \" and a '." +// --another_flag=wombats + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { + +// Calls tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet +// unrecognized flags passed in the environment variable `envvar`, and returns +// its return value. +// +// Raises a fatal error if any flags in `envvar` were not recognized. +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list); + +// Used only for testing. Not to be used by clients. 
+void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc similarity index 89% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc rename to tensorflow/compiler/xla/parse_flags_from_env_test.cc index 138c0c852e2bb0527d171f25b4d96cedc5671516..3465552ebbf52140fb954b247d99d3c6afe7fcde 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include #include @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Test that XLA flags can be set from the environment. // Failure messages are accompanied by the text in msg[]. @@ -38,20 +37,7 @@ static void TestParseFlagsFromEnv(const char* msg) { // Initialize module under test. int* pargc; std::vector* pargv; - ResetFlagsFromEnvForTesting(&pargc, &pargv); - - // Ensure that environment variable can be parsed when - // no flags are expected. - std::vector empty_flag_list; - bool parsed_ok = ParseFlagsFromEnv(empty_flag_list); - CHECK(parsed_ok) << msg; - const std::vector& argv_first = *pargv; - CHECK_NE(argv_first[0], nullptr) << msg; - int i = 0; - while (argv_first[i] != nullptr) { - i++; - } - CHECK_EQ(i, *pargc) << msg; + ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv); // Check that actual flags can be parsed. 
bool simple = false; @@ -66,7 +52,7 @@ static void TestParseFlagsFromEnv(const char* msg) { tensorflow::Flag("single_quoted", &single_quoted, ""), tensorflow::Flag("double_quoted", &double_quoted, ""), }; - parsed_ok = ParseFlagsFromEnv(flag_list); + bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); CHECK_EQ(*pargc, 1) << msg; const std::vector& argv_second = *pargv; CHECK_NE(argv_second[0], nullptr) << msg; @@ -159,12 +145,11 @@ TEST(ParseFlagsFromEnv, EnvAndFlag) { } } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { // Save name of binary so that it may invoke itself. - xla::legacy_flags::binary_name = argv[0]; + xla::binary_name = argv[0]; bool recursing = false; xla::int32 int_flag = 1; const std::vector flag_list = { @@ -173,7 +158,8 @@ int main(int argc, char* argv[]) { tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list); + bool parse_ok = + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); if (!parse_ok) { LOG(QFATAL) << "can't parse from environment\n" << usage; } diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 21685c4a5b90f76440e4cf10cce004b6cf925cc8..63ac1c6649210cbae9e238a74e0a45fb8ee4da63 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") py_library( name = "xla_client", @@ -66,6 +67,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/service:platform_util", 
"//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", @@ -81,6 +83,7 @@ tf_py_wrap_cc( srcs = ["xla.i"], swig_includes = [ "local_computation_builder.i", + "//tensorflow/python:platform/base.i", ], deps = [ ":local_computation_builder", @@ -89,5 +92,7 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ], + ] + if_cuda_is_configured([ + "//tensorflow/compiler/xla/service:gpu_plugin", + ]), ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b1fae826ab1903fb73541a7ae32b5cc57b3b92a7..4d2a37cfac3e0e89d189f168031e5db44ca5d410 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -56,6 +57,12 @@ tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; +string* GetPlatformNameString() { + static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = + new string("Host"); + return platform_name_string; +} + Status InitializeReplicaCount(int replica_count) { if (replica_count < 1) { return InvalidArgument("Replica count must be >= 1; got %d.", @@ -72,17 +79,33 @@ Status InitializeReplicaCount(int replica_count) { return Status::OK(); } +Status InitializePlatformName(const 
string& platform_name) { + string* g_platform_name = GetPlatformNameString(); + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return FailedPrecondition( + "Attempted to set the platform name to %s, but a local XLA service was " + "previously created with a platform name of %s.", + platform_name, *g_platform_name); + } + TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); + *g_platform_name = platform_name; + return Status::OK(); +} + int GetReplicaCount() { tensorflow::mutex_lock lock(g_local_client_mutex); return g_replica_count; } LocalClient* GetOrCreateLocalClient() { + string* platform_name = GetPlatformNameString(); tensorflow::mutex_lock lock(g_local_client_mutex); if (g_local_client != nullptr) { return g_local_client; } LocalClientOptions options; + options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); options.set_number_of_replicas(g_replica_count); g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); CHECK(g_local_client != nullptr); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 82f84ddb35bd4455fd3607509c6329457cca47f3..9e617c48bdc5ae4b37c1a1db9a1876bb4c0a6f0d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -39,6 +39,12 @@ namespace swig { // returned. Status InitializeReplicaCount(int replica_count); +// Initializes the platform name that XLA will be initialized with (when +// first obtaining a handle to the local XLA service). If this is called after +// the handle to the local XLA service has been established, then an error is +// returned. +Status InitializePlatformName(const string& platform_name); + // Returns the replica count that is currently set, regardless of whether the // local XLA service has been instantiated yet or not. 
int GetReplicaCount(); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index c13d00d2530c7e9321d483a70e4a12361159362d..feabfdb889ca055550c5d1e1c05ca47c1b0bd166 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -977,6 +977,7 @@ tensorflow::ImportNumpy(); %unignore xla; %unignore xla::swig; %unignore xla::swig::InitializeReplicaCount; +%unignore xla::swig::InitializePlatformName; %unignore xla::swig::GetReplicaCount; %unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::TransferToInfeedLocalReplica; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 07e0e093255b2baf3412852821fe62fa060f6cad..92b0685dbba195405d78867776fe43b5f6c60f4c 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1371,6 +1371,18 @@ def initialize_replica_count(replica_count): c_api.InitializeReplicaCount(replica_count) +def initialize_platform_name(platform_name): + """Initializes the desired platform name to use on XLA service init. + + Args: + platform_name: string name of platform. + + Raises: + A runtime exception if the XLA service has already been initialized. + """ + c_api.InitializePlatformName(platform_name) + + def get_replica_count(): """Returns the current replica count used for the XLA service. 
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 4e1435fa30a24c320ddbedb84d37b369a3158a54..d8123a6de28ca532819ece4a75cd0b725f8c1bbd 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,11 +47,18 @@ namespace xla { }); } -::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, + const CompileRequest* arg, + CompileResponse* result) { return DelegateRPC( - [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); + [this, arg, result]() { return service_->Compile(arg, result); }); +} + +::grpc::Status GRPCService::Execute(::grpc::ServerContext* /*context*/, + const ExecuteRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Execute(arg, result); }); } ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index ca1b09b648013ad45d806040c5ddcf11d9e5604e..3e586b288a56a22573d0c3b9ae7b2f25fdbf851a 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,9 +39,13 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + ::grpc::Status Compile(::grpc::ServerContext* context, + const CompileRequest* arg, + CompileResponse* result) override; + + ::grpc::Status Execute(::grpc::ServerContext* context, + const ExecuteRequest* arg, + ExecuteResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const 
WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 7b8ab158e1396d7087a407be180ab44d2e16e121..66abf66cfd6c2f753c5507aa373452ac880e9a29 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,10 +62,17 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::Compile(const CompileRequest* request, + CompileResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteGraph(context, *request, response); + return grpc_stub_->Compile(context, *request, response); + }); +} + +Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Execute(context, *request, response); }); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 8dfcb761387d608abbb1f62974f49b976a7ff7ff..f02b401399f3e895153f0b08e325bc9c2c2336ec 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,8 +43,11 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status Compile(const CompileRequest* request, + CompileResponse* response) override; + + Status Execute(const ExecuteRequest* request, + ExecuteResponse* response) override; Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index 
551ae895e05586daec0ffcd425f4950f76bdd50d..e4f332cda22cc5b889bf73f06913b96d6091dc81 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -128,11 +128,14 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. + // Compiles the provided computation into executable. Returns the handle of + // the executable. + rpc Compile(CompileRequest) returns (CompileResponse) {} + + // Invokes the provided executable with the provided global data passed as + // immutable arguments. The request contains the handle to the executable. // Returns global data output and execution timing. - rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) { - } + rpc Execute(ExecuteRequest) returns (ExecuteResponse) {} // Invokes the provided list of computations in parallel with the provided // global data for each computation. 
Returns a list of global data output and diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0097e917c869b19908ea11b3a647ecc9bad12dc7..1bd04d2785913c59929478974883b9669e1c1185 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,7 +87,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -124,7 +123,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -158,12 +156,12 @@ tf_cc_test( ":bfloat16_propagation", ":bfloat16_support", ":hlo", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -281,7 +279,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -294,6 +292,7 @@ cc_library( name = "hlo", srcs = [ "dfs_hlo_visitor.cc", + 
"dynamic_parameter_binding.cc", "hlo_computation.cc", "hlo_input_output_alias_config.cc", "hlo_instruction.cc", @@ -307,6 +306,7 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "dynamic_parameter_binding.h", "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", @@ -323,7 +323,6 @@ cc_library( ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", - ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal", @@ -353,6 +352,25 @@ cc_library( ], ) +tf_cc_test( + name = "dynamic_parameter_binding_test", + srcs = ["dynamic_parameter_binding_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + tf_cc_test( name = "dfs_hlo_visitor_with_default_test", srcs = ["dfs_hlo_visitor_with_default_test.cc"], @@ -365,7 +383,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -402,10 +419,12 @@ cc_library( srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], deps = [ + ":hlo", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], @@ -420,7 +439,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -466,7 +484,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -519,7 +536,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -568,7 +584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -591,7 +606,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -603,11 +617,11 @@ cc_library( hdrs = ["platform_util.h"], deps = [ ":compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -647,6 +661,7 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", + ":compilation_cache", ":compiler", 
":computation_layout", ":device_memory_allocator", @@ -662,6 +677,7 @@ cc_library( ":source_map_util", ":stream_pool", ":transfer_manager", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -673,7 +689,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", @@ -730,12 +745,12 @@ cc_library( ":computation_layout", ":platform_util", ":service", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -811,6 +826,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", + "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/memory", ], @@ -833,6 +849,7 @@ cc_library( ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", @@ -840,7 +857,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", 
"//tensorflow/core:stream_executor_no_cuda", @@ -1086,7 +1102,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1103,6 +1118,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_reachability", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1168,7 +1184,6 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1343,6 +1358,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1362,6 +1378,7 @@ cc_library( ":fusion_queue", ":hlo", ":hlo_pass", + ":hlo_reachability", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", @@ -1387,6 +1404,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":hlo_reachability", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", @@ -1427,7 +1445,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -1503,7 +1520,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1555,7 +1571,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1592,7 +1607,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1642,7 +1656,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1694,6 +1708,19 @@ cc_library( ], ) +tf_cc_test( + name = "while_loop_analysis_test", + srcs = ["while_loop_analysis_test.cc"], + deps = [ + ":while_loop_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1702,9 +1729,11 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", + ":hlo_query", + ":pattern_matcher", ":while_loop_analysis", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/core:lib", 
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -1716,10 +1745,17 @@ tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ + ":algebraic_simplifier", + ":hlo", + ":hlo_cse", + ":hlo_dce", ":hlo_matchers", + ":hlo_pass", + ":hlo_pass_pipeline", + ":tuple_simplifier", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -1750,7 +1786,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1778,7 +1814,7 @@ tf_cc_test( ":implicit_broadcast_remover", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1823,7 +1859,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1857,7 +1892,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -2263,7 +2298,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2326,13 +2360,27 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], ) +cc_library( + name = "compilation_cache", + srcs = ["compilation_cache.cc"], + hdrs = ["compilation_cache.h"], + deps = [ + ":executable", + ":hlo_module_config", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "layout_assignment", srcs = [ @@ -2402,14 +2450,13 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2527,7 +2574,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2594,7 +2640,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2656,7 +2701,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2697,7 +2742,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2736,7 +2780,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2808,10 +2851,9 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -2844,6 +2886,46 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_get_dimension_size_rewriter", + srcs = ["hlo_get_dimension_size_rewriter.cc"], + hdrs = ["hlo_get_dimension_size_rewriter.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":shape_inference", + "//tensorflow/compiler/xla:literal", + 
"//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "hlo_get_dimension_size_rewriter_test", + srcs = ["hlo_get_dimension_size_rewriter_test.cc"], + deps = [ + ":hlo", + ":hlo_get_dimension_size_rewriter", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "device_memory_allocator", srcs = [ @@ -2902,6 +2984,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", @@ -2999,7 +3082,6 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -3278,6 +3360,8 @@ cc_library( ":tuple_util", "//tensorflow/compiler/xla:literal_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", ], ) @@ -3304,10 +3388,11 @@ cc_library( ":hlo", ":hlo_pass", ":tuple_util", + ":while_loop_analysis", ":while_util", + "//tensorflow/compiler/xla:shape_util", 
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3323,7 +3408,7 @@ tf_cc_test( ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3353,7 +3438,7 @@ tf_cc_test( ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3366,6 +3451,7 @@ cc_library( ":bfloat16_normalization", ":defuser", ":hlo", + ":hlo_memory_scheduler", ":hlo_pass", ":hlo_pass_pipeline", ":implicit_broadcast_remover", @@ -3413,7 +3499,7 @@ tf_cc_test( ":indexed_array_analysis", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", ], @@ -3499,6 +3585,41 @@ cc_library( ], ) +cc_library( + name = "ar_crs_combiner", + srcs = ["ar_crs_combiner.cc"], + hdrs = ["ar_crs_combiner.h"], + deps = [ + ":call_graph", + ":pattern_matcher", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = 
"ar_crs_combiner_test", + srcs = ["ar_crs_combiner_test.cc"], + deps = [ + ":ar_crs_combiner", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "map_inliner_test", srcs = ["map_inliner_test.cc"], @@ -3510,7 +3631,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 85fc42f74756458ee677e8b53448ceb02f08e834..56bf3a9f69d718db1b2845c6901a893a2fe1660b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include #include #include #include @@ -83,7 +84,8 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. 
bool ReshapeOrCopyIsBitcast( const HloInstruction* instr, - const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { + const AlgebraicSimplifierOptions::ValidBitcastCallback& + valid_bitcast_callback) { CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); @@ -107,6 +109,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override; + Status HandleAnd(HloInstruction* logical_and) override; + Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBitcastConvert(HloInstruction* bitcast) override; @@ -141,6 +145,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMultiply(HloInstruction* multiply) override; + Status HandleNegate(HloInstruction* negate) override; + + Status HandleNot(HloInstruction* logical_not) override; + + Status HandleOr(HloInstruction* logical_or) override; + Status HandlePad(HloInstruction* pad) override; Status HandlePower(HloInstruction* power) override; @@ -171,21 +181,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const bool changed() const { return changed_; } // Runs the visitor on a computation. 
- static bool Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification); + static bool Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options); private: - explicit AlgebraicSimplifierVisitor( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) - : computation_(computation), - is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + explicit AlgebraicSimplifierVisitor(HloComputation* computation, + const AlgebraicSimplifierOptions& options) + : computation_(computation), options_(options) {} // Transforms Dots where at least one input is a vector or has a degenerate // dimension and converts it into a multiply and reduce. This should enable @@ -224,10 +226,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* new_instruction); // Returns whether the shape of the output of the given instructions are the - // same for the purposes of simplification. If is_layout_sensitive_ is true, - // then this tests shape equality including layout (ShapeUtil::Equal). If - // is_layout_sensitive_ is false, then the tests shape compatibility - // (ShapeUtil::Compatible). + // same for the purposes of simplification. If options_.is_layout_sensitive() + // is true, then this tests shape equality including layout + // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the + // tests shape compatibility (ShapeUtil::Compatible). 
bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; // Returns whether it was possible to transform `root` to a clamp instruction. @@ -306,30 +308,22 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Tries to use a kDot in place of the given convolution. StatusOr SimplifyConvToDot(HloInstruction* convolution); - // Tries to simplify a slice(pad(...)) where the result of the slice is a - // scalar. - StatusOr TrySimplifySliceOfPad(HloInstruction* slice); + // Tries to simplify a slice where the result of the slice is a scalar. + StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + + // Tries to convert slice(reshape(X)) into reshape(slice(X)) + StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; + // The backend-specific options selected for the algebraic simplifier. + const AlgebraicSimplifierOptions& options_; + // Whether algebraic simplification has occurred. bool changed_ = false; - // Whether layout is considered during transformation. - bool is_layout_sensitive_; - - // Callback used to determine if a bitcast is possible. - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - - // Disable dot strength reduction on platforms where it causes a slowdown. - bool enable_dot_strength_reduction_; - - // Disable convolution -> dot simplification on platforms where it causes a - // slowdown. - bool enable_conv_simplification_; - // Cached computation for adding two scalar F32. 
HloComputation* scalar_add_computation_ = nullptr; }; @@ -337,19 +331,15 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { } // namespace bool AlgebraicSimplifierVisitor::Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) { - AlgebraicSimplifierVisitor visitor( - computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_strength_reduction, enable_conv_simplification); + HloComputation* computation, const AlgebraicSimplifierOptions& options) { + AlgebraicSimplifierVisitor visitor(computation, options); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const { - if (is_layout_sensitive_) { + if (options_.is_layout_sensitive()) { return ShapeUtil::Equal(lhs->shape(), rhs->shape()); } else { return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); @@ -423,6 +413,43 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); + // Simplify logical and + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A && True => A + VLOG(10) << "trying transform [A && True => A]: " + << logical_and->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + // True && A => A + VLOG(10) << "trying transform [True && A => A]: " + << logical_and->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // A && False => False + VLOG(10) << "trying transform [A && False => False]: " + << 
logical_and->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // False && A => False + VLOG(10) << "trying transform [False && A => False]: " + << logical_and->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. HloInstruction* op; @@ -456,8 +483,8 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { ReplaceWithBitcast(copy); } @@ -1167,7 +1194,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_gather_optimized); } - if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); if (did_strength_reduction) { @@ -1229,6 +1257,64 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { + // negate(negate(x)) => x + HloInstruction* x; + if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) && + ReplaceInstructionIfSameShape(negate, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) { + // not(not(x)) => x + HloInstruction* x; + if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) && + ReplaceInstructionIfSameShape(logical_not, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status 
AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs)))); + + // Simplify logical or + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A || True => True + VLOG(10) << "trying transform [A || True => True]: " + << logical_or->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + // True || A => True + VLOG(10) << "trying transform [True || A => True]: " + << logical_or->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // A || False => A + VLOG(10) << "trying transform [A || False => A]: " + << logical_or->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // False || A => A + VLOG(10) << "trying transform [False || A => A]: " + << logical_or->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); @@ -1804,8 +1890,8 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. 
- if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -1826,60 +1912,160 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifySliceOfPad( +StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( HloInstruction* slice) { // Only try to do this for effective scalars. We could do the same for slicing // out larger pieces of padding (replacing with a broadcast of the padding // value), but this is probably not worth it. - if (!ShapeUtil::IsEffectiveScalar(slice->shape()) || - slice->operand(0)->opcode() != HloOpcode::kPad) { + if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { return false; } - VLOG(10) << "Trying to simplify scalar slice of pad"; - // Check there's no internal padding. Again, we could handle that too, since - // everything is statically known, but it's not worth it. - auto pad = Cast(slice->mutable_operand(0)); - auto padding_config = pad->padding_config(); - int64 rank = padding_config.dimensions_size(); - if (HasInteriorPadding(padding_config)) { - VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; - return false; + if (slice->operand(0)->opcode() == HloOpcode::kPad) { + VLOG(10) << "Trying to simplify scalar slice of pad"; + // Check there's no internal padding. Again, we could handle that too, since + // everything is statically known, but it's not worth it. + auto pad = Cast(slice->mutable_operand(0)); + auto padding_config = pad->padding_config(); + int64 rank = padding_config.dimensions_size(); + if (HasInteriorPadding(padding_config)) { + VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; + return false; + } + + // Check whether the scalar we're slicing out falls into the padding. 
+ bool in_padding = [&]() { + for (int64 i = 0; i < rank; ++i) { + int64 start = slice->slice_starts(i); + int64 low = padding_config.dimensions(i).edge_padding_low(); + int64 data = pad->operand(0)->shape().dimensions(i); + if (start >= low && start < low + data) { + return false; + } + } + return true; + }(); + + if (in_padding) { + VLOG(10) << "Folding scalar slice of pad into padding value"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), + pad->mutable_padding_value()))); + return true; + } else { + // We already know the output of the slice is scalar. If the padded + // value is scalar, and it's not in the padding, then it's exactly the + // output value. + bool replaced = + ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + if (replaced) { + VLOG(10) << "Folding scalar slice of pad into padded value"; + } else { + VLOG(10) << "Not folding scalar slice of pad into padded value as they " + "have different shapes."; + } + return replaced; + } } - // Check whether the scalar we're slicing out falls into the padding. - bool in_padding = [&]() { - for (int64 i = 0; i < rank; ++i) { - int64 start = slice->slice_starts(i); - int64 low = padding_config.dimensions(i).edge_padding_low(); - int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { + VLOG(10) << "Trying to simplify scalar slice of concat"; + // Only do this for R1, there's no chance of this being useful otherwise. + if (ShapeUtil::Rank(slice->shape()) != 1) { + VLOG(10) << "Not folding, slice is not rank 1"; + return false; + } + HloConcatenateInstruction* concat = + Cast(slice->mutable_operand(0)); + int64 operand_start = 0; + int64 operand_num = 0; + // Weird loop structure to avoid annoying off-by-one errors. 
+ while (true) { + TF_RET_CHECK(operand_num < concat->operand_count()); + const HloInstruction* operand = concat->operand(operand_num); + int64 next_operand_start = operand_start + operand->shape().dimensions(0); + if (next_operand_start > slice->slice_starts(0)) { + break; } + operand_start = next_operand_start; + operand_num++; } - return true; - }(); - if (in_padding) { - VLOG(10) << "Folding scalar slice of pad into padding value"; - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( - slice, HloInstruction::CreateReshape(slice->shape(), - pad->mutable_padding_value()))); - return true; - } else { - // We already know the output of the slice is scalar. If the padded - // value is scalar, and it's not in the padding, then it's exactly the - // output value. - bool replaced = - ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + bool replaced = ReplaceInstructionIfSameShape( + slice, concat->mutable_operand(operand_num)); if (replaced) { - VLOG(10) << "Folding scalar slice of pad into padded value"; + VLOG(10) << "Folding scalar slice of concat into concat operand"; } else { - VLOG(10) << "Not folding scalar slice of pad into padded value as they " - "have different shapes."; + VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateSlice( + slice->shape(), concat->mutable_operand(operand_num), + {slice->slice_starts(0) - operand_start}, + {slice->slice_starts(0) - operand_start + 1}, + slice->slice_strides()))); + } + return true; + } + + return false; +} + +bool IsUnstridedSlice(const HloInstruction* hlo) { + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); +} + +StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( + HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!IsUnstridedSlice(slice)) { + return false; + } + HloInstruction* reshape = slice->mutable_operand(0); + if 
(reshape->opcode() != HloOpcode::kReshape) { + return false; + } + HloInstruction* new_slice_operand = reshape->mutable_operand(0); + int64 slice_rank = ShapeUtil::Rank(slice->shape()); + std::vector sliced_dims; + for (int64 i = 0; i < slice_rank; ++i) { + if (slice->slice_starts(i) != 0 || + slice->slice_limits(i) != reshape->shape().dimensions(i)) { + sliced_dims.push_back(i); + } + } + + if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && + slice->slice_starts(0) == 0) { + const Shape& new_slice_shape = new_slice_operand->shape(); + const int64 rank = ShapeUtil::Rank(new_slice_shape); + std::vector new_slice_starts(rank, 0); + std::vector new_slice_stides(rank, 1); + std::vector new_slice_limits(new_slice_shape.dimensions().begin(), + new_slice_shape.dimensions().end()); + int64 slice_elements = ShapeUtil::ElementsIn(slice->shape()); + for (int64 i = rank - 1; i >= 0; --i) { + if (slice_elements >= new_slice_limits[i]) { + if (slice_elements % new_slice_limits[i] != 0) { + return false; + } + slice_elements /= new_slice_limits[i]; + } else { + new_slice_limits[i] = slice_elements; + slice_elements = 1; + } } - return replaced; + HloInstruction* new_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(new_slice_shape.element_type(), + new_slice_limits), + new_slice_operand, new_slice_starts, new_slice_limits, + new_slice_stides)); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); + return true; } + return false; } Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { @@ -1888,12 +2074,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return Status::OK(); } - auto is_unstrided_slice = [](const HloInstruction* hlo) { - return absl::c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); - }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && - is_unstrided_slice(slice) && 
is_unstrided_slice(slice->operand(0))) { + IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) { HloInstruction* operand_slice = slice->mutable_operand(0); std::vector new_slice_starts = slice->slice_starts(); std::vector new_slice_limits = slice->slice_limits(); @@ -1907,11 +2089,15 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { new_slice_starts, new_slice_limits, slice->slice_strides())); } - TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifySliceOfPad(slice)); + TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); if (replaced) { return Status::OK(); } + TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + if (replaced) { + return Status::OK(); + } return Status::OK(); } @@ -2295,6 +2481,108 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { return ReplaceWithNewInstruction( sort, HloInstruction::CreateTuple(sort->operands())); } + if (!options_.enable_permutation_sort_replacement()) { + return Status::OK(); + } + // Check if we are sorting a permutation. In that case, we know that the keys + // will be sorted to the identity permutation, and we can represent the + // changes to the 'values' parameter as a scatter. + if (sort->operand_count() == 2 && + operand->opcode() == HloOpcode::kGetTupleElement) { + const HloInstruction* other_sort = operand->operand(0); + // Check whether the 'values' parameter is the result of another sort with + // the same sort dimension. + if (other_sort->opcode() == HloOpcode::kSort && + other_sort->operand_count() >= 2 && + other_sort->dimensions(0) == dimension_to_sort && + other_sort->operand(operand->tuple_index())->opcode() == + HloOpcode::kIota) { + auto* iota = + Cast(other_sort->operand(operand->tuple_index())); + // The sort operand needs to be an integral iota, and the iota dimension + // needs to be the dimension that was sorted. 
+ if (iota->iota_dimension() == dimension_to_sort && + ShapeUtil::ElementIsIntegral(iota->shape())) { + // We use the following construction method for a Scatter that applies + // the permutation from 'keys' to the 'values' parameter. + // - Take the "keys" parameter of the second sort and reshape it to have + // another "1" dimension at the end. + // - Concatenate it with iotas of the same extended shape with all + // different iota_dimensions except the dimension_to_sort in the order + // of iota_dimensions/dimension_to_sort, so e.g. with rank 3 and + // dimension_to_sort = 1, we would have concatenate of (iota with + // iota_dimension=0, keys, iota with iota_dimension = 2) + // - Use this as the indices parameter of scatter, and set updates + // of the scatter to be a reshaped 'values' parameter of sort (adding + // 'rank' many 1 dimensions at the end). + int64 rank = ShapeUtil::Rank(operand->shape()); + Shape extended_shape = operand->shape(); + extended_shape.add_dimensions(1); + extended_shape.mutable_layout()->add_minor_to_major(rank); + auto reshaped_permutation = computation_->AddInstruction( + HloInstruction::CreateReshape(extended_shape, operand)); + std::vector concat_operands; + for (int64 i = 0; i < rank; ++i) { + if (i == dimension_to_sort) { + concat_operands.push_back(reshaped_permutation); + } else { + concat_operands.push_back(computation_->AddInstruction( + HloInstruction::CreateIota(extended_shape, i))); + } + } + Shape concat_shape = operand->shape(); + concat_shape.add_dimensions(rank); + concat_shape.mutable_layout()->add_minor_to_major(rank); + auto scatter_indices = + rank > 1 ? computation_->AddInstruction( + HloInstruction::CreateConcatenate( + concat_shape, concat_operands, rank)) + : reshaped_permutation; + + // We don't care about the operand, it will be completely overridden by + // the updates. 
+ auto scatter_operand = computation_->AddInstruction( + HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); + + // Construct the updates operand of scatter. + Shape update_shape = sort->operand(1)->shape(); + for (int64 i = 0; i < rank; ++i) { + update_shape.add_dimensions(1); + update_shape.mutable_layout()->add_minor_to_major(rank + i); + } + auto scatter_updates = + computation_->AddInstruction(HloInstruction::CreateReshape( + update_shape, sort->mutable_operand(1))); + + // Construct the updates computation, which simply replaces the operand + // values with the update values. + HloComputation::Builder b("update_replace_computation"); + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); + auto update_replace_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); + + ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(rank); + for (int64 i = 0; i < rank; ++i) { + dim_numbers.add_update_window_dims(rank + i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + auto scatter = + computation_->AddInstruction(HloInstruction::CreateScatter( + sort->operand(1)->shape(), scatter_operand, scatter_indices, + scatter_updates, update_replace_computation, dim_numbers)); + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple( + {computation_->AddInstruction(HloInstruction::CreateIota( + operand->shape(), dimension_to_sort)), + scatter})); + } + } + } return Status::OK(); } @@ -2319,7 +2607,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return ReplaceInstruction(transpose, operand); } - if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { + if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); } 
@@ -2468,13 +2756,13 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (!enable_conv_simplification_) { + if (!options_.enable_conv_simplification()) { return false; } // TODO(b/31337498): For now, we cowardly refuse to do this optimization in // layout-insensitive mode, for fear of adding nontrivial reshapes. - if (!is_layout_sensitive_) { + if (!options_.is_layout_sensitive()) { return false; } @@ -2564,9 +2852,9 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be // invalid. - if (!valid_bitcast_callback_(input_shape, new_input_shape) || - !valid_bitcast_callback_(filter_shape, new_filter_shape) || - !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || + !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || + !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { return false; } @@ -2672,9 +2960,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run( - comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_strength_reduction_, enable_conv_simplification_)) { + if (AlgebraicSimplifierVisitor::Run(comp, options_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f8d0ee88bdebcf17310cd0407b1b99e4b0a7b5f..d2775b9fafa7e4c625f5d181114e80e7369f9c78 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,8 +23,7 @@ limitations 
under the License. namespace xla { -// A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloModulePass { +class AlgebraicSimplifierOptions { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform @@ -34,18 +33,63 @@ class AlgebraicSimplifier : public HloModulePass { using ValidBitcastCallback = std::function; + explicit AlgebraicSimplifierOptions( + ValidBitcastCallback valid_bitcast_callback) + : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + // If valid_bitcast_callback returns true, then the pass will replace reshapes + // and transposes with bitcasts. + const ValidBitcastCallback& valid_bitcast_callback() const { + return valid_bitcast_callback_; + } + + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + void set_is_layout_sensitive(bool is_layout_sensitive) { + is_layout_sensitive_ = is_layout_sensitive; + } + bool is_layout_sensitive() const { return is_layout_sensitive_; } + + // Enable dot simplification on platforms where it is profitable. + void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { + enable_dot_strength_reduction_ = enable_dot_strength_reduction; + } + bool enable_dot_strength_reduction() const { + return enable_dot_strength_reduction_; + } + + // Enable convolution simplification on platforms where it is profitable. + void set_enable_conv_simplification(bool enable_conv_simplification) { + enable_conv_simplification_ = enable_conv_simplification; + } + bool enable_conv_simplification() const { + return enable_conv_simplification_; + } + + // If enable_permutation_sort_replacement is true, a sort op that is known to + // sort a permutation will be replaced with a scatter op. 
+ void set_enable_permutation_sort_replacement( + bool enable_permutation_sort_replacement) { + enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; + } + bool enable_permutation_sort_replacement() const { + return enable_permutation_sort_replacement_; + } + + private: + ValidBitcastCallback valid_bitcast_callback_; + bool is_layout_sensitive_{false}; + bool enable_dot_strength_reduction_{true}; + bool enable_conv_simplification_{true}; + bool enable_permutation_sort_replacement_{false}; +}; + +// A pass which performs algebraic simplifications. +class AlgebraicSimplifier : public HloModulePass { + public: // If is_layout_sensitive is true, then the simplifier preserves layout during - // transformation. Otherwise, layout is ignored. If valid_bitcast_callback - // returns true, then the pass will replace reshapes and transposes with - // bitcasts. - AlgebraicSimplifier(bool is_layout_sensitive, - ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction = true, - bool enable_conv_simplification = true) - : is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + // transformation. Otherwise, layout is ignored. + explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) + : options_(options) {} ~AlgebraicSimplifier() override = default; absl::string_view name() const override { return "algsimp"; } @@ -54,14 +98,7 @@ class AlgebraicSimplifier : public HloModulePass { StatusOr Run(HloModule* module) override; private: - bool is_layout_sensitive_; - ValidBitcastCallback valid_bitcast_callback_; - - // Enable dot simplification on platforms where it is profitable. - bool enable_dot_strength_reduction_; - - // Enable convolution simplification on platforms where it is profitable. 
- bool enable_conv_simplification_; + AlgebraicSimplifierOptions options_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 7b3e957fbcf9f4628c4aeb0c323d50d3ed36a4f2..8b8ba2a77d9bec7a6baf6929a0566906727be319 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -46,18 +45,22 @@ using ::testing::ElementsAre; namespace op = xla::testing::opcode_matchers; -AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } -AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase { + protected: + AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -67,18 +70,18 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - 
auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { + auto m = CreateNewVerifiedModule(); Shape r0s32 = ShapeUtil::MakeShape(S32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -88,12 +91,11 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), zero); } @@ -114,8 +116,7 @@ TEST_F(AlgebraicSimplifierTest, SelectTrue) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } 
@@ -137,8 +138,7 @@ TEST_F(AlgebraicSimplifierTest, SelectFalse) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param1); } @@ -158,14 +158,14 @@ TEST_F(AlgebraicSimplifierTest, SelectIdentical) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSelect); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param1); } // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. 
HloInstruction* zero = builder.AddInstruction( @@ -180,7 +180,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); HloInstruction* param = builder.AddInstruction( @@ -193,17 +193,17 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, dims1, add_computation)); - module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reduce(param, zero)); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } // Test that Const + A is canonicalized to A + Const. 
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -213,18 +213,18 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Constant())); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. 
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -239,17 +239,17 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -261,17 +261,17 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } 
TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. HloComputation* add_computation = nullptr; @@ -284,7 +284,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); HloInstruction* param0 = builder.AddInstruction( @@ -297,17 +297,17 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateBroadcast(r2f32, zero, {}))}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMap); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -319,64 +319,64 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); 
EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14, 3.14, 4}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); 
EXPECT_THAT(root, op::Constant()); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); } // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -386,18 +386,18 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A - Const is canonicalized to A + (-Const). 
TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -407,18 +407,18 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); } // Test that (A/B)/C is simplified to A/(B*C). 
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -432,14 +432,13 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Divide(param0, param1), param2)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Multiply(param1, param2))); @@ -447,6 +446,7 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { // Test that A/(B/C) is simplified to (A*C)/B. 
TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -460,14 +460,13 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Divide(param1, param2))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Multiply(param0, param2), param1)); @@ -475,6 +474,7 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -492,15 +492,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -509,6 +508,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { // Test that A/exp(B) is simplified to A*exp(-B). 
TEST_F(AlgebraicSimplifierTest, DivOfExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -520,14 +520,13 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Exp(param1))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Exp(op::Negate(param1)))); @@ -535,6 +534,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfPower) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -548,14 +548,13 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -564,6 +563,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -577,14 +577,13 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -592,6 +591,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { // A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -602,11 +602,10 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Constant())); @@ -614,6 +613,7 @@ 
TEST_F(AlgebraicSimplifierTest, DivideByConstant) { // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -627,10 +627,9 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Power(base, op::Multiply(exp1, exp2))); } @@ -638,6 +637,7 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + auto m = CreateNewVerifiedModule(); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -651,14 +651,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, inner_power, exp2)); - module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } // Test that A/1 is simplified to A for a scalar. 
TEST_F(AlgebraicSimplifierTest, DivOneScalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -668,18 +668,18 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A/1 is simplified to A for an array. TEST_F(AlgebraicSimplifierTest, DivOneArray) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -689,18 +689,18 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that complex(real(c), imag(c)) is simplified to c. 
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); HloComputation::Builder builder(TestName()); @@ -713,18 +713,18 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that real(complex(r,i)) is simplified to r. 
TEST_F(AlgebraicSimplifierTest, RealOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -737,18 +737,18 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that imag(complex(r,i)) is simplified to i. TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -761,18 +761,18 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } // Test that get_element(make_tuple({A,B}),1) is simplified to B 
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -788,18 +788,18 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param1, param2)); } // Test that exp(A)/exp(B) is simplified to exp(A-B) TEST_F(AlgebraicSimplifierTest, ExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -813,14 +813,13 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Exp(param0), op::Exp(param1))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Subtract(param0, param1))); @@ -828,6 +827,7 @@ 
TEST_F(AlgebraicSimplifierTest, ExpDiv) { // Test that exp(A)*exp(B) is simplified to exp(A+B) TEST_F(AlgebraicSimplifierTest, ExpMul) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -841,14 +841,13 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Exp(param0), op::Exp(param1))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Add(param0, param1))); @@ -856,6 +855,7 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { // Test that pow(exp(A), B) is simplified to exp(A*B) TEST_F(AlgebraicSimplifierTest, PowExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -867,14 +867,13 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(op::Exp(param0), param1)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Exp(op::Multiply(param0, param1))); @@ -882,6 +881,7 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { // Test that ln(pow(A, B)) is simplified to ln(A)*B TEST_F(AlgebraicSimplifierTest, LnPow) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -893,14 +893,13 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Power(param0, param1))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Log(param0), param1)); @@ -908,6 +907,7 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { // Test that ln(exp(A)) is simplified to A TEST_F(AlgebraicSimplifierTest, LnExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -917,19 +917,19 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that ln(exp(A)/exp(B)) is simplified to A-B TEST_F(AlgebraicSimplifierTest, LnExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -945,14 +945,13 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } @@ -960,6 +959,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { // Test that pow(A, 0) where A is a scalar is simplified to the scalar // constant 1. 
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -969,13 +969,12 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -984,6 +983,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
TEST_F(AlgebraicSimplifierTest, Pow0Vector) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {42}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -993,13 +993,12 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast()); @@ -1012,6 +1011,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { // Test that pow(A, 1) is simplified to A. TEST_F(AlgebraicSimplifierTest, Pow1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1021,19 +1021,19 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that pow(A, 2) is simplified to A*A. 
TEST_F(AlgebraicSimplifierTest, Pow2) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1043,19 +1043,19 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } // Test that pow(A, -1) is simplified to 1/A. TEST_F(AlgebraicSimplifierTest, PowNegative1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1065,13 +1065,12 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); @@ -1081,6 +1080,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { } 
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); @@ -1113,17 +1113,17 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Convolution(lhs, rhs)); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1148,24 +1148,24 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); - 
module().AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::ReduceWindow(param, op::Constant())); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1182,17 +1182,17 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); - module().AddEntryComputation(builder.Build()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + m->AddEntryComputation(builder.Build()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Pad(param, op::Constant())); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + HloPassFix simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -1206,39 +1206,39 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation 
= builder.Build(); - module().AddEntryComputation(std::move(computation)); + m->AddEntryComputation(std::move(computation)); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reshape(op::Broadcast(op::Reshape(op)))); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloPassFix simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), op); + EXPECT_THAT(m->entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } // Test that copies are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1246,18 +1246,18 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1268,24 +1268,27 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param)); - AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier1(options); + ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. 
EXPECT_THAT(computation->root_instruction(), op::Copy(param)); - AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, - bitcasting_callback()); - ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options2(bitcasting_callback()); + options2.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier2(options2); + ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } // Test that unary concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1293,19 +1296,19 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } // Test that empty operands of concatenates are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1322,15 +1325,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0, param0, param1)); @@ -1338,6 +1340,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { // Test that reduce of concat is simplified. 
TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r3f32 = ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength}); @@ -1363,7 +1366,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength}); @@ -1373,11 +1376,10 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { builder.AddInstruction(HloInstruction::CreateReduce( reduce_shape, Concatenate, zero, {1, 2}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -1387,6 +1389,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { // Test a concatenate with only empty operands is removed. 
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1401,20 +1404,20 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(empty_literal, empty_slice)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } // Test that concat with a scalar broadcast becomes a pad. 
TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); @@ -1427,17 +1430,17 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); } // Test that a simplification which changes layouts is not performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1445,7 +1448,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to different layouts. 
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1453,9 +1456,10 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -1464,6 +1468,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { // Test that a simplification which preserves layouts is performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1471,7 +1476,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to same layouts. 
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1479,9 +1484,10 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Copy has been removed. EXPECT_THAT(computation->root_instruction(), param0); @@ -1490,6 +1496,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { // Test that a reshape which could be replaced with a bitcast is not if // add_bitcasts is false. TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1502,13 +1509,14 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -1516,6 +1524,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { // Test transforming reshapes and transposes of rng. 
TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); @@ -1532,11 +1541,11 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { ShapeUtil::MakeShape(F32, {4}), transpose)) ->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that that reshape(transpose(rng)) is replace by a single rng of the // same shape as the reshape. @@ -1547,6 +1556,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { // Test transforming reshapes to bitcasts under various conditions. 
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1578,15 +1588,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Tuple(transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); - simplifier.Run(&module()).ValueOrDie(); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + simplifier.Run(m.get()).ValueOrDie(); // Verify that only the first reshape is replaced. EXPECT_THAT( @@ -1597,6 +1608,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. 
@@ -1611,15 +1623,16 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. @@ -1635,13 +1648,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1655,19 +1669,21 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = 
m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1681,19 +1697,21 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1706,19 +1724,19 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Reshape(param0))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1733,18 +1751,20 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } 
TEST_F(AlgebraicSimplifierTest, TransposesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1757,13 +1777,12 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); EXPECT_EQ(std::vector({2, 1, 0}), @@ -1772,6 +1791,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { // Test merging reshape and broadcast. 
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {5}), "param0")); @@ -1780,20 +1800,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(op::Reshape(param0))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } // Test merging broadcast and reshape. 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0")); @@ -1802,19 +1822,19 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param0))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1823,20 +1843,20 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); 
} TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1845,14 +1865,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); EXPECT_THAT(computation->root_instruction()->dimensions(), @@ -1860,6 +1879,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1868,14 +1888,13 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector broadcast_dims = @@ -1885,6 +1904,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1893,33 +1913,32 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction(HloInstruction::CreateIota( ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); - 
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); EXPECT_TRUE( @@ -1927,18 +1946,18 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { } TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); auto result_shape = iota->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Iota()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -1948,37 +1967,37 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + 
AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); EXPECT_EQ(Cast(computation->root_instruction()) @@ -1987,19 +2006,19 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); const int64 iota_dim = @@ -2009,19 +2028,19 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); } @@ -2043,14 +2062,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2076,11 +2094,10 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( 
ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); auto has_negative_padding = [](const HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { @@ -2095,7 +2112,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -2110,14 +2127,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2133,14 +2149,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); 
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2162,14 +2177,13 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice, /*start_indices=*/{2, 3}, /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); @@ -2178,6 +2192,55 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + const int64 dim2 = 13; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {dim0, dim1, dim2}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2}), original_reshape, + /*start_indices=*/{0, 0, 0}, + /*limit_indices=*/{dim0 - 2, dim1, dim2}, 
/*strides=*/{1, 1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {3600, 512}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {960, 512}), original_reshape, + /*start_indices=*/{0, 0}, + /*limit_indices=*/{960, 512}, /*strides=*/{1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); @@ -2185,14 +2248,86 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + 
AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } +TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Tuple(op::Iota(), + op::Scatter(op::Iota(), + op::Concatenate(op::Iota(), op::Reshape()), + op::Reshape()))); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { + // Same as ReplacePermutationSortWithScatter except that the iota has F32 + // type. 
+ const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { + // Same as ReplacePermutationSortWithScatter except that the sort dimensions + // don't match. + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); @@ -2207,15 +2342,182 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { builder.AddInstruction(HloInstruction::CreateSort( ShapeUtil::MakeTupleShape({keys_shape, values_shape, 
values_shape}), 0, keys, {values0, values1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values0, values1)); } +// Test that A && True is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that True && A is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_true, param0)); + + auto 
computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A && False is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that False && A is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + 
AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that A || True is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that True || A is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test 
that A || False is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that False || A is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Used for TEST_Ps that test merging (or not) of a kPad instruction into a // convolution's Window. 
struct ConvPaddingTestcase { @@ -2337,15 +2639,14 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(), lhs_pad, filter, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); @@ -2455,15 +2756,14 @@ TEST_P(ConvFilterPaddingTest, DoIt) { input, rhs_pad, /*feature_group_count=*/1, window, dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); @@ -2604,11 +2904,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. 
- auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + simplifier_options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } @@ -2724,20 +3025,19 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2763,16 +3063,15 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2782,7 +3081,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2837,8 +3136,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
@@ -2864,7 +3162,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2923,8 +3221,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. @@ -2954,12 +3251,11 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2970,6 +3266,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { // Dots add computations to the parent module. Test that, when the HloModule's // computations are updated, then iterator invalidation doesn't occur // when running on subsequent computations. 
+ auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {1}); HloComputation::Builder builder(TestName() + ".Dot"); HloInstruction* x = @@ -2991,15 +3288,15 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - module().AddEmbeddedComputation(std::move(dot_computation)); - module().AddEntryComputation(call_builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + m->AddEmbeddedComputation(std::move(dot_computation)); + m->AddEntryComputation(call_builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -3008,11 +3305,10 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } @@ -3021,6 +3317,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { // of its input equals the size of its output. 
In this case, the dynamic slice // is equal to its input. TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -3032,10 +3329,9 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Parameter()); } @@ -3043,6 +3339,7 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update. 
TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -3065,16 +3362,16 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::DynamicSlice(op::Parameter(), op::Parameter())); } // Test that two consecutive broadcasts can be merged to one. TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( @@ -3085,12 +3382,11 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { builder.AddInstruction( HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); EXPECT_THAT(root->dimensions(), ElementsAre(2)); @@ -3098,6 +3394,7 @@ 
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { // Test that two consecutive broadcasts can be merged to one. TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); @@ -3111,12 +3408,11 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); @@ -3124,6 +3420,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { // Test that a broadcast of an iota can be merged to one iota. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* iota = @@ -3131,12 +3428,11 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); EXPECT_EQ(Cast(root)->iota_dimension(), 2); @@ -3144,6 +3440,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { // Test that a broadcast of an iota can be merged to one iota. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); HloInstruction* iota = @@ -3152,12 +3449,11 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); EXPECT_EQ(Cast(root)->iota_dimension(), 2); @@ -3174,12 +3470,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Constant())); @@ -3196,12 +3491,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Constant())); @@ -3218,12 +3512,11 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } @@ -3238,17 +3531,102 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Parameter()); } +TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), 
dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[2:3]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(1)); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[4:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_EQ(root->slice_starts(0), 1); + EXPECT_EQ(root->slice_limits(0), 2); +} + +TEST_F(AlgebraicSimplifierTest, NegateNegate) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + neg.0 = f32[2] negate(param.0) + ROOT neg.1 = f32[2] negate(neg.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(0)); +} + +TEST_F(AlgebraicSimplifierTest, NotNot) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = pred[2] parameter(0) + not.0 = pred[2] not(param.0) + ROOT 
not.1 = pred[2] not(not.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(0)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -3278,6 +3656,7 @@ class PadReduceWindowEffectiveBroadcastTest PadReduceWindowEffectiveBroadcastCase> {}; TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { + auto m = CreateNewVerifiedModule(); const auto& param = GetParam(); // a and b are parallel bounds we can either turn into a B F S0 S1 or @@ -3326,7 +3705,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Window window = window_util::MakeWindow( @@ -3340,10 +3719,9 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { builder.AddInstruction(HloInstruction::CreateReduceWindow( output_shape, pad, zero, window, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -3392,6 +3770,7 @@ class DotStrengthReductionTest public ::testing::WithParamInterface< ::testing::tuple> {}; 
TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + auto module = CreateNewVerifiedModule(); int m, k, n; bool transpose_lhs, transpose_rhs; PrimitiveType element_type; @@ -3421,10 +3800,9 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_dnums.add_rhs_contracting_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = dot_should_be_transformed || (transpose_lhs && transpose_rhs); @@ -3452,7 +3830,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloVerifiedTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // Test that we transform @@ -3460,6 +3838,7 @@ class DotOfConcatSimplificationTest // to // add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3498,10 +3877,9 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = 
m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -3519,6 +3897,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { // to // add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3562,10 +3941,9 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3590,6 +3968,7 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { // Test that DynamicUpdateSlice update param with any dimension equal to zero // gets removed. 
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); HloInstruction* const operand = builder.AddInstruction( @@ -3602,11 +3981,10 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), operand); } @@ -3625,7 +4003,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloVerifiedTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) @@ -3634,6 +4012,7 @@ class DotOfGatherSimplificationTest // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. 
TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3680,10 +4059,9 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3704,6 +4082,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. 
TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3750,10 +4129,9 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc new file mode 100644 index 0000000000000000000000000000000000000000..c11452a6fbd49a1fc382d11d24a7d7b7eeab0bcc --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -0,0 +1,286 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace m = match; + +// If the argument instruction is a CRS in the sequence +// AR -> Convert -> Add -> CRS +// then return the AR in the sequence. +// TODO(b/117554291): Rewrite this to recognize more general patterns, +// not just the specific one of AR -> Add -> Convert -> CRS. 
+absl::optional MatchesArCrsPattern( + HloInstruction* instruction) { + HloInstruction *ar, *convert, *add, *crs; + if (Match(instruction, + m::CrossReplicaSum( + &crs, m::Add(&add, m::Op(), + m::Convert(&convert, + m::CrossReplicaSum(&ar, m::Op()))))) && + ar->users().size() == 1 && ar->shape().element_type() == BF16 && + convert->shape().element_type() == F32 && !crs->all_reduce_id()) { + return ar; + } + return absl::optional(); +} + +} // namespace + +absl::optional ArCrsCombiner::WhileFromBodyParameter( + HloInstruction* instruction) { + CHECK(HloOpcode::kParameter == instruction->opcode()); + HloComputation* computation = instruction->parent(); + auto caller_instructions = call_graph_->GetComputationCallers(computation); + if (caller_instructions.size() == 1) { + auto caller_instruction = caller_instructions[0]; + if (caller_instruction->opcode() == HloOpcode::kWhile) { + return caller_instruction; + } + } + return absl::optional(); +} + +std::vector ArCrsCombiner::GetAllTuples( + HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kTuple) { + return {instruction}; + } + if (instruction->opcode() == HloOpcode::kDomain) { + return GetAllTuples(instruction->operands()[0]); + } + if (instruction->opcode() == HloOpcode::kParameter) { + auto maybe_while = WhileFromBodyParameter(instruction); + if (!maybe_while) { + return {}; + } + auto while_instr = *maybe_while; + auto init_tuples = GetAllTuples(while_instr->while_init()); + auto body_tuples = + GetAllTuples(while_instr->while_body()->root_instruction()); + if (init_tuples.empty() || body_tuples.empty()) { + return {}; + } + init_tuples.insert(init_tuples.end(), body_tuples.begin(), + body_tuples.end()); + return init_tuples; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + std::vector result_tuples; + for (auto tuple : GetAllTuples(instruction->operands()[0])) { + auto tmp_tuples = + GetAllTuples(tuple->mutable_operand(instruction->tuple_index())); + if 
(tmp_tuples.empty()) { + return {}; + } + result_tuples.insert(result_tuples.end(), tmp_tuples.begin(), + tmp_tuples.end()); + } + return result_tuples; + } + return {}; +} + +bool ArCrsCombiner::TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, + absl::flat_hash_map* visited_pairs) { + auto tuples = GetAllTuples(tuple_shaped_instruction); + if (tuples.empty()) { + return false; + } + for (auto tuple : tuples) { + CHECK(tuple->opcode() == HloOpcode::kTuple); + if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), + tuple->mutable_operand(i2), + visited_pairs)) { + return false; + } + } + return true; +} + +/* static */ +bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2) { + ArCrsCombiner combiner(/*num_spatial_partitions=*/2); + auto module = i1->parent()->parent(); + CHECK_EQ(module, i2->parent()->parent()); + combiner.call_graph_ = CallGraph::Build(module); + absl::flat_hash_map visited_pairs; + return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs); +} + +bool ArCrsCombiner::InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs) { + if (i1 == i2) { + return true; + } + auto uid1 = i1->unique_id(); + auto uid2 = i2->unique_id(); + auto min_uid = std::min(uid1, uid2); + auto max_uid = std::max(uid1, uid2); + auto it = visited_pairs->find(min_uid); + if (it != visited_pairs->end() && max_uid == it->second) { + return true; + } + auto opcode1 = i1->opcode(); + auto operands1 = i1->operands(); + if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { + return false; + } + if (opcode1 == HloOpcode::kConstant || i1->IsCrossModuleAllReduce()) { + return i1->Identical( + *i2, + /*eq_operands=*/std::equal_to(), + /*eq_computations=*/std::equal_to(), + /*layout_sensitive=*/false); + } + visited_pairs->emplace(min_uid, max_uid); + for (int i = 0; i < operands1.size(); ++i) { + auto operand1 
= operands1[i]; + auto operand2 = i2->operands()[i]; + if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) { + return false; + } + } + if (opcode1 == HloOpcode::kGetTupleElement) { + if (i1->tuple_index() == i2->tuple_index()) { + return true; + } + return TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + i2->tuple_index(), visited_pairs); + } + return true; +} + +void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + auto ar = MatchesArCrsPattern(instruction); + if (ar) { + all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + } + } + } +} + +void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { + for (auto it : all_reduce_map_) { + auto instruction_vec = it.second; + CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); + + auto instr_0 = instruction_vec[0]; + auto add_0 = instr_0->users()[0]->users()[0]; + CHECK(HloOpcode::kAdd == add_0->opcode()); + + for (int i = 1; i < instruction_vec.size(); ++i) { + auto instr_i = instruction_vec[i]; + auto add_i = instr_i->users()[0]->users()[0]; + CHECK(HloOpcode::kAdd == add_i->opcode()); + absl::flat_hash_map visited_pairs; + if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { + all_reduce_map_.erase(it.first); + } + } + } +} + +StatusOr ArCrsCombiner::RewriteGraph() { + if (all_reduce_map_.empty()) { + return false; + } + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + for (auto it : all_reduce_map_) { + auto instruction_vec = it.second; + for (auto all_reduce : instruction_vec) { + auto parent_computation = all_reduce->parent(); + auto convert = all_reduce->users()[0]; + auto add = convert->users()[0]; + auto crs = add->users()[0]; + + if 
(!computation_is_addition(all_reduce->called_computations()[0]) || + !computation_is_addition(crs->called_computations()[0])) { + continue; + } + HloInstruction* other_summand = (add->operands()[0] == convert) + ? add->operands()[1] + : add->operands()[0]; + // Remove the AllReduce and replace the CRS with: + // AllReduce - (other_summand * (num_spatial_partitions_ - 1)) + TF_CHECK_OK( + all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); + crs->set_all_reduce_id(all_reduce->all_reduce_id()); + auto new_shape = crs->shape(); + HloInstruction* to_subtract; + if (num_spatial_partitions_ == 2) { + to_subtract = other_summand; + } else { + Literal partitions_minus_1_lit = Literal(new_shape); + partitions_minus_1_lit.PopulateWithValue( + num_spatial_partitions_ - 1); + auto partitions_minus_1_const = parent_computation->AddInstruction( + HloInstruction::CreateConstant(partitions_minus_1_lit.Clone())); + to_subtract = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + new_shape, HloOpcode::kMultiply, other_summand, + partitions_minus_1_const)); + } + auto sub = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + new_shape, HloOpcode::kSubtract, crs, to_subtract)); + TF_CHECK_OK(crs->ReplaceAllUsesWith(sub)); + TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + } + } + + return true; +} + +StatusOr ArCrsCombiner::Run(HloModule* module) { + call_graph_ = CallGraph::Build(module); + + GroupAllReducesById(module); + + KeepProvablyEqualInstructionGroups(); + + return RewriteGraph(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h new file mode 100644 index 0000000000000000000000000000000000000000..f6a7ef76ec3b76972d1b2c7fb548cecfb9423160 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Combine an AllReduce and a CrossReplicaSum when they are close to each other +// in the graph, to use an efficient CrossReplicaSum implementation that +// fully utilizes the interconnect bandwidth. +class ArCrsCombiner : public HloModulePass { + public: + ArCrsCombiner(int num_spatial_partitions) + : num_spatial_partitions_(num_spatial_partitions) {} + absl::string_view name() const override { return "ar-crs-combiner"; } + StatusOr Run(HloModule* module) override; + + // Helper method to allow testing of InstructionsComputeSameValue. + static bool TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2); + + private: + // If the passed instruction is a while parameter, and the while body is only + // called by a single while instruction, return the while instruction. + absl::optional WhileFromBodyParameter( + HloInstruction* instruction); + + // Returns a vector of tuple instructions. 
+ // If all instructions that flow to "instruction" are tuples, return them. + // Otherwise, return an empty vector. + std::vector GetAllTuples(HloInstruction* instruction); + + // Checks whether two different elements in the same tuple compute the same + // value. + bool TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, + absl::flat_hash_map* visited_pairs); + + // Returns whether the instructions i1 and i2 can be shown to evaluate to the + // same value. Handling WHILE requires recursion, which may cause us to visit + // the same instruction again. To avoid infinite loops, we pass a cache of + // visited instruction pairs. + bool InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs); + + // Populates all_reduce_map_. + void GroupAllReducesById(HloModule* module); + + // Looks at each AllReduce group in all_reduce_map_, and keeps only the + // groups for which it's safe to move the AllReduce later in the HLO graph. + void KeepProvablyEqualInstructionGroups(); + + // Performs the graph rewrite that eliminates the early AllReduce and turns + // the later CRS into an AllReduce. + StatusOr RewriteGraph(); + + int num_spatial_partitions_; + + // Map from all-reduce ids to the all reduce instructions. + absl::flat_hash_map> all_reduce_map_; + + std::unique_ptr call_graph_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d5eaf63ccf32cd78b8c11f12f9bccdfd1fec3e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -0,0 +1,415 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ArCrsCombinerTest : public HloTestBase {}; + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue( + i1, module->entry_computation()->parameter_instruction(0))); + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { + const char* module_str = R"( +HloModule foobar + 
+ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple1 = (f32[2,2]) tuple(%constant.f32) + %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + 
%get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile1) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] 
get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile2) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %init.tuple = (f32[2,2], f32[2,2]) 
tuple(%constant.f32.1, %constant.f32.2) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile3) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]->operands()[0]; // 
%get-tuple-element.1 + auto i2 = body_tuple->operands()[1]->operands()[0]; // %get-tuple-element.2 + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + + %cross-replica-sum.ar.1 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Subtract(op::CrossReplicaSum(), op::Constant()), + op::Subtract(op::CrossReplicaSum(), op::Constant()))); + auto sub = module->entry_computation()->root_instruction()->operands()[0]; + auto crs_after = sub->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); + for (int i = 0; i < replica_groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. + auto group_before = replica_groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = replica_groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + + %cross-replica-sum.ar.1 = bf16[2,2] + cross-replica-sum(%constant.bf16), + 
replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32.1, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32.2, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a645f98220ec445bb9bbdf2b9b842109..52ec1a794c5e9f4452a4bf2b648f453d8acfe976 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -17,14 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { @@ -38,11 +37,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -61,11 +61,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -84,11 +85,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, 
op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -107,11 +109,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -130,11 +133,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -153,11 +157,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index f7ac8f5482908af104554a1cf812370b9098cda7..08cf8026177d77ff98cca5e5d168ac3194936b35 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ 
b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloVerifiedTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -59,14 +59,14 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { param0, param1, param2, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -101,14 +101,14 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { param1, param2, param3, param4, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie()); - for (auto* instruction : module().entry_computation()->instructions()) { + for (auto* instruction : m->entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 5f93740887aa7e61458990992fe0573883ff056d..4ce351acc2c359773e618da70360c96faf5ca379 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { +class BFloat16ConversionFoldingTest : public HloTestBase { protected: BFloat16ConversionFoldingTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -103,10 +103,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -138,10 +138,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -173,10 +173,10 @@ TEST_F(BFloat16ConversionFoldingTest, 
DoNotFoldUnsupportedMixedPrecision) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -203,10 +203,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { HloInstruction* convert1 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -216,7 +216,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("add"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -252,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index cb075a5e38a5ea9db2ceb432b2b59f8db5e2e640..9f97d18c565c7915b9f9346f0c6330cdc3c707e9 100644 
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloVerifiedTestBase { +class BFloat16NormalizationTest : public HloTestBase { protected: BFloat16NormalizationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -106,10 +106,10 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module)); + EXPECT_FALSE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -134,10 +134,10 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { HloInstruction* mul1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); 
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -164,10 +164,10 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { HloInstruction* sub1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -191,7 +191,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, reduce_comp_param0, reduce_comp_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(reduce_comp_builder.Build()); @@ -205,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -233,7 +233,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -263,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = 
module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -272,7 +272,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -290,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -299,7 +299,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -314,7 +314,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); @@ -342,10 +342,10 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); 
auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 0af71eaac96fca366e45430788e769c618f86bb5..5be7141aae423adb4fe2f39262e463ff25ae8234 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloVerifiedTestBase { +class BFloat16PropagationTest : public HloTestBase { protected: BFloat16PropagationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
@@ -121,10 +121,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,6 +136,62 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { EXPECT_FALSE(OutputsBF16(c)); } +TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) { + auto module = CreateNewVerifiedModule(); + + auto sub_builder = HloComputation::Builder("max"); + HloInstruction* p0 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "a")); + HloInstruction* p1 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "b")); + sub_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, p0, p1)); + auto max_computation = module->AddEmbeddedComputation(sub_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + 
*window.add_dimensions() = dim; + HloInstruction* rw = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + shape, add, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), + window, max_computation)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), c, {1, 0})); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, rw)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(rw)); +} + // Tests that side-effecting all-reduce should not be changed. TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { auto module = CreateNewVerifiedModule(); @@ -186,10 +242,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -242,10 +298,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + 
EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -281,10 +337,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction* dot = builder.AddInstruction( CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -310,10 +366,10 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -321,7 +377,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { // Tests that BF16 is propagated properly through fused computations. 
TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -356,7 +412,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -369,7 +425,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { // Tests that changes to BF16 that cannot be propagated outside a fusion are // discarded. TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -393,7 +449,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -408,7 +464,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { // (BF16, BF16) fusion_computation(F32 a, F32 b) // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -439,7 +495,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); 
EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -458,7 +514,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { // on_true and on_false must match, so that as long as one of them is F32, the // other must be F32 as well. TEST_F(BFloat16PropagationTest, SelectOverTuples) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -489,7 +545,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -502,7 +558,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { // Tests that BF16 is propagated properly through a while computation with // non-tuple input/output. TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -545,7 +601,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -561,7 +617,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { // made to the while body and thus the fusion node inside it. 
TEST_F(BFloat16PropagationTest, ConditionPreventsPropagationForFusionInsideWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -610,7 +666,7 @@ TEST_F(BFloat16PropagationTest, auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -622,7 +678,7 @@ TEST_F(BFloat16PropagationTest, // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -690,7 +746,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -709,7 +765,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // Tests that BF16 is not propagated through multiple whiles that invoke the // same computation as long as one while prevents the propagation. 
TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -820,7 +876,7 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -859,10 +915,10 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( bf16_shape, HloOpcode::kAdd, convert0, convert1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -895,10 +951,10 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -941,10 +997,10 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 5b48f10505e78c035608d4c575501e4623218987..2b9502f63a821f3675ddfb506f41bb2390cf4136 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -107,6 +108,21 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kSelect: case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + HloComputation* reduce_comp = hlo.called_computations()[0]; + for (HloInstruction* inst : reduce_comp->instructions()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) { + return false; + } + } + } + return true; + } default: break; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index ee4e5942731110e16c8396a824e6dbd19c9df607..8d7c62447852fd946440c41389300a92377c471f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -641,7 +641,7 @@ Status BufferAssignment::ComputeSummaryStats() { bool schedule_complete = true; 
for (const auto& computation : module_->computations()) { if (!computation->IsFusionComputation()) { - const std::vector* sequence = + const HloInstructionSequence* sequence = liveness_->hlo_ordering().SequentialOrder(*computation); if (sequence == nullptr) { schedule_complete = false; @@ -746,8 +746,7 @@ StatusOr> BufferAssigner::Run( LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing, bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) { - BufferAssigner assigner(allow_input_output_aliasing, - allocate_buffers_for_constants, std::move(colorer), + BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), std::move(reuse_checker)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), @@ -1180,7 +1179,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloComputation* computation = pair.first; const flat_hash_set& buffers_to_assign = pair.second; - const std::vector* instruction_sequence = + const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); schedule.set_sequence(computation, *instruction_sequence); @@ -1215,7 +1214,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloComputation* computation = pair.first; const flat_hash_set& buffers_to_assign = pair.second; - const std::vector* instruction_sequence = + const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); auto color_map = SplitBuffersByColor(buffers_to_assign); @@ -1230,7 +1229,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), *computation, - HloInstructionSequence(*instruction_sequence), + 
*instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1434,33 +1433,40 @@ BufferAssigner::MergeColocatedBufferSets( computation == module->entry_computation(); }; + std::vector set_can_be_merged(colocated_buffer_sets.size(), true); + + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. + for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { + for (auto& buffer : colocated_buffer_sets[i]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { + set_can_be_merged[i] = false; + break; + } + } + } + // Returns true if the two colocated buffer sets (specified by their indices // into the colocated_buffer_sets) can be merged into a single set. auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, - &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs, entry parameters or - // constants. - // - // Buffer liveness does not report the correct live range for entry - // parameter and live out buffers so we have to special case them here. 
On - // backends that support constant buffer allocations, constant buffers are - // assigned globals in readonly storage so we can't merge colocated buffer - // sets containing constants with colocated buffer sets containing writing - // instructions or other constants. - // - // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to - // the caller of the executable so we can't write to entry parameters - // either, and the argument for not merging constants also applies to entry - // parameters. - for (int64 key : {i, j}) { - for (auto& buffer : colocated_buffer_sets[key]) { - if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || - buffer->instruction()->opcode() == HloOpcode::kConstant) { - return true; - } - } + &set_can_be_merged](int64 i, int64 j) { + if (!set_can_be_merged[i] || !set_can_be_merged[j]) { + return true; } // Colocated sets satisfy the invariant that all buffers within a set have diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index d8e1612b899f10a5793f9c65c59a41024dfdddd1..0a9fdede803e84ca42472259084615c031b206eb 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -545,12 +545,10 @@ class BufferAssigner { ReuseAllocationFunction reuse_checker = nullptr); private: - BufferAssigner(bool allow_input_output_aliasing, - bool allocate_buffers_for_constants, + BufferAssigner(bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) - : allow_input_output_aliasing_(allow_input_output_aliasing), - allocate_buffers_for_constants_(allocate_buffers_for_constants), + : allocate_buffers_for_constants_(allocate_buffers_for_constants), colorer_(colorer), reuse_checker_(reuse_checker) {} virtual ~BufferAssigner() = default; @@ -640,10 +638,6 @@ class BufferAssigner { LogicalBuffer::Color::Hasher> SplitBuffersByColor(const 
absl::flat_hash_set& buffers); - // If true, buffer assignments assumes that input parameter buffers and output - // buffers can be shared if their sizes match. - bool allow_input_output_aliasing_; - // If true, allocate buffers for constant instructions. bool allocate_buffers_for_constants_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 327211d3efd24177a28cc2d08dc3c4fbf2fbaff9..8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -81,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloVerifiedTestBase { +class BufferAssignmentTest : public HloTestBase { protected: ~BufferAssignmentTest() override {} @@ -137,8 +137,7 @@ class BufferAssignmentTest : public HloVerifiedTestBase { } std::unique_ptr RunBufferAssignmentWithInstructionSequence( - HloModule* module, - absl::Span instruction_sequence, + HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { HloSchedule schedule(module); schedule.set_sequence(module->entry_computation(), instruction_sequence); @@ -334,16 +333,16 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto module = 
CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -358,17 +357,17 @@ TEST_F(BufferAssignmentTest, BufferForConst) { LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -387,10 +386,10 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate, param0, constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. 
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -410,10 +409,10 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // The copy node now has an output buffer. GetAssignedOutputAllocation(*buffers, copy); } @@ -439,10 +438,10 @@ TEST_F(BufferAssignmentTest, Basic) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -538,7 +537,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -553,7 +552,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -599,7 +598,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -622,7 +621,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -671,10 +670,10 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -706,7 +705,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // param0[100x10] ---> (map x+1) // // Builds the map function. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto map_computation = module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); auto inner_last = map_computation->root_instruction(); @@ -725,7 +724,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. 
- auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -761,7 +760,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // out-of-order reductions could overwrite an element before a use.) // // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); @@ -784,7 +783,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -812,7 +811,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // const4[f32[4]] --- tuple --- while[condition, body] // // Builds the nested condition and body. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto condition_computation = module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); auto body_computation = @@ -840,7 +839,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. 
- auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -878,7 +877,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { } TEST_F(BufferAssignmentTest, ExampleConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto true_computation = module->AddEmbeddedComputation( BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); auto false_computation = module->AddEmbeddedComputation( @@ -905,7 +904,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -941,9 +940,9 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto neg = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -970,9 +969,9 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. 
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1003,9 +1002,9 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1040,9 +1039,9 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1075,9 +1074,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. 
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1107,9 +1106,9 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1145,9 +1144,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); builder.AddInstruction(HloInstruction::CreateTuple({broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1160,7 +1159,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Verify that buffers for embedded computations are properly marked as // thread-local and that embedded parameters are not marked as // is_entry_computation_parameter. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto vec_shape = ShapeUtil::MakeShape(F32, {42}); auto scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1191,7 +1190,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Allocations for the map computation should be thread-local and not // live-out. 
@@ -1238,9 +1237,9 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { ShapeUtil::MakeShape(S32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1274,9 +1273,9 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Only some of the elements of the input param are liveout. EXPECT_FALSE( @@ -1318,9 +1317,9 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1332,9 +1331,9 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), ShapeUtil::MakeShape(S32, {101})}), /*operands=*/{}, /*custom_call_target=*/"foo_function")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1347,7 +1346,7 @@ 
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { TEST_F(BufferAssignmentTest, TupleCallAsOutput) { // Test a computation which returns a tuple call value. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1365,7 +1364,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1388,7 +1387,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { // B: call(C, param) // C: call(D, param) // D: param - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1427,7 +1426,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1461,9 +1460,9 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto bitcast = builder.AddInstruction( HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Bitcast should get the same allocation as the param. 
EXPECT_EQ(1, assignment->Allocations().size()); @@ -1488,9 +1487,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, pred_param, tuple_param0, tuple_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. @@ -1526,9 +1525,9 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto copy = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape, HloOpcode::kCopy, tuple_element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. @@ -1568,9 +1567,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module, /*alignment=*/1); + auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1589,7 +1588,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. 
- assignment = RunBufferAssignment(module, /*alignment=*/64); + assignment = RunBufferAssignment(module.get(), /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1632,10 +1631,10 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = @@ -1673,11 +1672,11 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module, {param, log, rev, neg, concat, root}); + module.get(), {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. 
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1698,7 +1697,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { } TEST_F(BufferAssignmentTest, PeakBuffersWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); HloComputation* condition; { @@ -1733,7 +1732,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1783,13 +1782,13 @@ ENTRY main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* constant_1 = - module().entry_computation()->GetInstructionWithName("constant.1.1"); + m->entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module().entry_computation()->GetInstructionWithName("constant.1.2"); + m->entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); { const BufferAllocation& allocation_for_const_1 = @@ -1818,7 +1817,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloVerifiedTestBase { +class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( const string& name) { @@ -1853,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, 
absl::make_unique(schedule), ByteSizeOf, @@ -1878,7 +1877,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1917,8 +1916,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Verify 'input0' and read-only use while0{0} alias. EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1974,20 +1973,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. 
- const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* param = - module().entry_computation()->parameter_instruction(0); + m->entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1995,7 +1993,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2042,20 +2040,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. 
- const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* constant = - module().entry_computation()->GetInstructionWithName("constant.42"); + m->entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -2063,7 +2060,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2121,7 +2118,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto token = builder.AddInstruction(HloInstruction::CreateToken()); @@ -2156,7 +2153,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module).status()); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry @@ -2164,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // nodes are traversed during BufferAssignment. 
TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -2175,12 +2172,12 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run(module, - absl::make_unique(schedule), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run( + module.get(), absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2202,7 +2199,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2234,8 +2231,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // while0 and while1 buffers should be completely aligned. 
EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2247,7 +2244,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -2277,13 +2274,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); } - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2308,13 +2305,14 @@ ENTRY Main { )"; HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_text, config); + config.set_debug_options(GetDebugOptionsFromFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(hlo_text, config)); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); - HloComputation* main = module().entry_computation(); - HloComputation* callee = module().GetComputationWithName("Callee"); + HloComputation* main = m->entry_computation(); + HloComputation* callee = m->GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2338,7 +2336,7 @@ ENTRY Main { } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2385,40 +2383,41 @@ 
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } - RunCopyInsertion(module); + RunCopyInsertion(module.get()); HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - schedule.set_sequence(module->entry_computation(), - {input1, weights1, one, output1, while1->operand(0), - while1, input0, weights0, zero, output0, - while0->operand(0), while0, gte0, gte1, root_add}); + schedule.set_sequence( + module->entry_computation(), + {input1, weights1, one, output1, while1->mutable_operand(0), while1, + input0, weights0, zero, output0, while0->mutable_operand(0), while0, + gte0, gte1, root_add}); // If this ASSERT fails, we constructed a bogus sequence above and this test // itself is buggy. 
TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run(module, - absl::make_unique(schedule), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run( + module.get(), absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2462,8 +2461,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Get BufferAllocation for root instruction. 
auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 17e50905059ad2c92784d14132c1cb1f46f35ade..40825a78716b1c0b9fb0121787977d275891c0f8 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -117,7 +117,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto log = builder.AddInstruction( HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -164,7 +164,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -213,7 +213,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -247,7 +247,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -289,7 +289,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = 
module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -336,7 +336,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(add)); HloSchedule schedule(module.get()); @@ -373,7 +373,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto outer_tuple = builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -393,7 +393,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { TEST_F(BufferLivenessTest, EmbeddedComputation) { // Test MaybeLiveOut and MayInterfere for embedded computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); auto embedded_param = embedded_builder.AddInstruction( @@ -450,7 +450,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0.shape(), tuple_constant, 0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -514,7 +514,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -576,7 +576,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - 
auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -611,8 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - std::unique_ptr BuildModule(const bool update_uses_tuple_element1, - const bool fuse_gte0) { + std::unique_ptr BuildModule( + const bool update_uses_tuple_element1, const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -646,7 +646,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. @@ -802,7 +802,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. 
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index bdd5069632e84fe6c67ca129f726432479ac1b35..7987343bfaf1069fd550909d127e4b11f2124701 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -325,6 +325,15 @@ bool CallGraph::IsFlattened() const { return true; } +std::vector CallGraph::GetComputationCallers( + HloComputation* c) { + std::vector callers; + for (auto callsite : GetNode(c).caller_callsites()) { + callers.push_back(callsite.instruction()); + } + return callers; +} + std::pair CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, HloInstruction* b) const { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index cb56f4789d06ac33acdaadc8b619b9e37f683d58..05c7c998738f861ee804d1ec87bfa5fb17ddfb74 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -236,6 +236,10 @@ class CallGraph { // FlattenCallGraph. bool IsFlattened() const; + // Returns a vector of instructions calling the passed computation. + // (Often a vector of size 1.) + std::vector GetComputationCallers(HloComputation* c); + string ToString() const; private: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 34f3f914d593bc603c4964663f9cafb70a136fd3..a3ac2568b0f3eec8556a42dbe3c2c64bd8564468 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloVerifiedTestBase { +class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( @@ -93,10 +93,10 @@ class CallGraphTest : public HloVerifiedTestBase { TEST_F(CallGraphTest, SingletonComputation) { // Test the call graph of a module with a single computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -112,13 +112,13 @@ TEST_F(CallGraphTest, SingletonComputation) { TEST_F(CallGraphTest, UnreachableComputation) { // Test the call graph of a module with an entry computation and an // unreachable computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -134,13 +134,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { TEST_F(CallGraphTest, ParallelComputation) { // Test a call graph of a module with an entry computation which calls another // computation in a parallel context via kMap. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* map_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -163,13 +163,13 @@ TEST_F(CallGraphTest, ParallelComputation) { TEST_F(CallGraphTest, SequentialComputations) { // Test a call graph of a module with an entry computation which calls another // computation in a sequential context via kCall. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* called_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -196,7 +196,7 @@ TEST_F(CallGraphTest, SequentialComputations) { TEST_F(CallGraphTest, ContextBothComputations) { // Test a call graph of a module with an entry computation which calls another // computation in both a parallel and sequential context. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -239,7 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { TEST_F(CallGraphTest, ComputationWithConditional) { // Test a call graph of a module with a conditional. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* true_computation = module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); HloComputation* false_computation = @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(3, call_graph->nodes().size()); @@ -298,7 +298,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -418,7 +418,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { // c // // Calls are made via kCall, kWhile, and kMap instructions. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -479,10 +479,10 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { TEST_F(CallGraphTest, VisitSingletonComputation) { // Test the call graph visitor with a call graph with a single node. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -494,12 +494,12 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { TEST_F(CallGraphTest, VisitUnreachableComputation) { // Test the call graph visitor with a call graph with an unreachable node. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // Test visitation of only reachable nodes. { @@ -531,9 +531,9 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index e6b566543594a86eb5369ee9b7440f62618f6c5a..0b6e323f75c7a5dae127e20d2a4b92a83a72df3b 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). 
-using CallInlinerTest = HloVerifiedTestBase; +using CallInlinerTest = HloTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -51,7 +51,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { HloInstruction* one = inner.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* inner_computation = module->AddEmbeddedComputation(inner.Build()); @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -79,7 +79,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // returns false should be identical to just returning false). TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Create a lambda that calls a function that returns the false predicate. 
// Note we also use this lambda twice by reference, just to make the test a @@ -107,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -120,7 +120,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // whole pass. TEST_F(CallInlinerTest, InlineWithoutRunningPass) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { const Shape f32 = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( @@ -163,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc new file mode 100644 index 0000000000000000000000000000000000000000..2662fe46705f4936ce0d654df0943e7d30890ebe --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compilation_cache.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = counter++; + return id; +} + +} // namespace + +ExecutionHandle CompilationCache::Insert( + std::unique_ptr executable) { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = GetUniqueId(); + VLOG(2) << "inserting cache key: " << key; + CHECK_EQ(cache_.count(key), 0); + cache_.emplace(key, std::move(executable)); + + ExecutionHandle handle; + handle.set_handle(key); + return handle; +} + +StatusOr> CompilationCache::LookUp( + const ExecutionHandle& handle) const { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = handle.handle(); + VLOG(2) << "looking up cache key: " << key; + if (cache_.count(key) == 0) { + VLOG(2) << "cache key not found: " << key; + return InvalidArgumentStrCat("can not find executable with handle ", key); + } else { + auto& result = cache_.at(key); + VLOG(2) << "hit executable: " << result->module().name(); + return 
result; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..5f94def509d4d4a8950272cb498af5056a698ce0 --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// A cache which stores Executables indexed by computation handle and version. +// +// TODO(b/119042872): Provide mechanism for removing computations from the +// compilation cache. +class CompilationCache { + public: + CompilationCache() {} + + ExecutionHandle Insert(std::unique_ptr executable); + + // Lookup the Executable for the specified handle in the cache. 
Return a + // shared_ptr to the Executable if it exists in the cache. + StatusOr> LookUp( + const ExecutionHandle& handle) const; + + protected: + mutable tensorflow::mutex mutex_; + + using CacheKey = int64; + + absl::flat_hash_map> cache_ + GUARDED_BY(mutex_); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 6d67f970020d278cc7bf61b56350200d3e5cb926..67132274c0dcbfda831c79836d052bb51b753ec7 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 80c630c6201503d88a690f04a88f6fca6f3a438a..8f08c244908efb823b3870c19bdc3491fa87d44f 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -110,6 +110,6 @@ Compiler::GetPlatformCompilers() { } AotCompilationOptions::AotCompilationOptions() - : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + : debug_options_(GetDebugOptionsFromFlags()) {} } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index c899ffb9dc562426ef14c0d414469c04debeec70..844b42a38d7539cccd5c4e30071c0ea6693e3bba 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ 
b/tensorflow/compiler/xla/service/computation_placer.h @@ -105,8 +105,6 @@ class ComputationPlacer { // Map from platform kind to computation placer singleton. static std::map* GetPlatformComputationPlacers(); - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); }; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167d47af3c92ed35fa52594fa5da1e4af..289eb6d90239a72ecc0f3312a7e0e8453f946858 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -37,7 +37,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ConditionalSimplifierTest : public HloVerifiedTestBase { +class ConditionalSimplifierTest : public HloTestBase { public: // Makes a computation that contains a conditional with constant predicate. 
HloComputation* MakeConditional(HloModule* module); @@ -96,25 +96,28 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { - HloComputation* computation = MakeConditional(&module()); - ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); + ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Constant())); } TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* true_op = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -125,11 +128,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeConditional(&module()); + auto m = 
CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -138,18 +142,19 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); false_computation->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2..7f7f1503a099b3a67ed22cb5978c01da6cf8ba88 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -51,7 +51,8 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; // Runs the visitor on a computation. 
- static bool Run(HloComputation* computation); + static bool Run(HloComputation* computation, + bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -59,18 +60,24 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation) - : computation_(computation) {} + explicit ConvolutionVisitor(HloComputation* computation, + bool canonicalize_depthwise_filter = false) + : computation_(computation), + filter_expansion_(!canonicalize_depthwise_filter) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; // Whether rewrite has occurred. bool changed_ = false; + + // Whether filter expansion is required. + bool filter_expansion_; }; -bool ConvolutionVisitor::Run(HloComputation* computation) { - ConvolutionVisitor visitor(computation); +bool ConvolutionVisitor::Run(HloComputation* computation, + bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -190,9 +197,49 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { HloInstruction* filter_mask = GetExpandedFilterMask( filter->shape(), input_feature_dim, output_feature_dim, group_count, add); HloInstruction* expanded_filter; - // We want to repeat 'filter' in the 'input_feature_dim' dimension - // 'group_count' times. + if (group_size == 1) { + bool depthwise_separable = + (group_count == filter->shape().dimensions(output_feature_dim)); + // If the code generator handles depthwise separable convolutions + // inherently, then no filter expansion is needed. 
+ if (!filter_expansion_ && depthwise_separable) { + const int64 old_kernel_input_feature_dimension = + dim_numbers.kernel_input_feature_dimension(); + const int64 old_kernel_output_feature_dimension = + dim_numbers.kernel_output_feature_dimension(); + + // For depthwise convolutions, we want the kernel input feature dimension + // to be smaller than the output feature dimension. If that's not the + // case, we swap the dimensions. + if (old_kernel_input_feature_dimension > + old_kernel_output_feature_dimension) { + Shape reshaped_filter_shape = filter->shape(); + auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); + std::swap(dimensions[old_kernel_input_feature_dimension], + dimensions[old_kernel_output_feature_dimension]); + + auto reshaped_filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + dim_numbers.set_kernel_input_feature_dimension( + old_kernel_output_feature_dimension); + + dim_numbers.set_kernel_output_feature_dimension( + old_kernel_input_feature_dimension); + + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), + reshaped_filter, group_count, convolution->window(), dim_numbers, + convolution->precision_config()); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); + } + return Status::OK(); + } + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. 
Shape reshaped_filter_shape = ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); auto reshaped_filter = @@ -237,7 +284,7 @@ StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp)) { + if (ConvolutionVisitor::Run(comp, filter_expansion_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index ce0138e56fbd51daaf5d3ac329ccbe31a9fdbde7..cb6bc04c00a2ff10f970da2a07fb540a561dad5a 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -27,7 +27,8 @@ namespace xla { // convolutions with feature_group_count = 1. class ConvolutionFeatureGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter() {} + ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) + : filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { return "convolution-feature-group-converter"; @@ -36,6 +37,9 @@ class ConvolutionFeatureGroupConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + // Tells whether filter expansion is required. 
+ bool filter_expansion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 4e547d925f62dce1d2dd23a39a28ca8c23ba9f2f..df6059663876dfde71f4c75d3931b3d2de72c1df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -442,7 +442,6 @@ class CopyRemover { const HloOrdering& ordering, HloModule* module) : module_(module), alias_analysis_(alias_analysis), - ordering_(ordering), buffer_value_tracker_(*module, alias_analysis, ordering) {} // Try to elide the given copy. The copy is elided if the instruction is not @@ -1003,7 +1002,6 @@ class CopyRemover { HloModule* module_; const HloAliasAnalysis& alias_analysis_; - const HloOrdering& ordering_; // Object tracking the HLO values contained in each HLO buffer. BufferValueTracker buffer_value_tracker_; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4533ebb99bbba854a029fb8a9a1e31b023be720d..e4e9d7ba05c115be9dd0eb53ebd7de208d514efb 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -94,7 +94,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -114,7 +114,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -127,7 +127,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = @@ -181,7 +181,7 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -217,7 +217,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloInstruction* old_root = module->entry_computation()->root_instruction(); @@ -238,7 +238,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -261,7 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); @@ -283,7 +283,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -310,7 +310,7 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, @@ -351,7 +351,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -388,7 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -403,7 +403,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { class WhileCopyInsertionTest : public CopyInsertionTest { protected: - WhileCopyInsertionTest() : module_(CreateNewModule()) {} + WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {} // Builds a While condition computation which reads the induction variable // from the tuple parameter, and returns a predicate indicating whether this @@ -1295,7 +1295,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { TEST_F(CopyInsertionTest, SwizzlingWhile) { // Test a while instruction with a body which permutes its tuple parameter // elements. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1362,7 +1362,7 @@ TEST_F(CopyInsertionTest, CrossingParameters) { // | / \ | // | / \| // (p1 , p0) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1395,7 +1395,7 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1428,7 +1428,7 @@ TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1461,7 +1461,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1496,7 +1496,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { // | | | // | | | // +-- (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1534,7 +1534,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { // | Add----+ // | | | // +-- (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1569,7 +1569,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // the operation (instruction) on the element makes the live range of the // respective input and output elements 
different than if the instruction were // not there (as in the SwizzlingWhile test above). - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1632,7 +1632,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { // the while body is a single constant (both loop state elements are the same // constant). This means no copies are necessary because both loop state // elements are the same so interchanging them is a no-op. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1693,7 +1693,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { const Shape loop_state_shape = ShapeUtil::MakeTupleShape( {element_shape, element_shape, element_shape, element_shape}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, element_shape, "param_0")); @@ -1783,7 +1783,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). The body constant should be copied. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -1896,7 +1896,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); @@ -1936,7 +1936,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); @@ -2003,7 +2003,7 @@ std::unique_ptr MakeBenchmarkWhileBody( void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { tensorflow::testing::StopTiming(); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); CopyInsertion copy_insertion; const Shape element_shape = ShapeUtil::MakeShape(F32, {}); std::vector tuple_params(num_tuple_inputs); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 36e25cbe678e03f511934eb00af8c3834de2c63e..ce4c2a9cc69240b9565b35a3f2504d7fc9373917 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -96,6 +96,7 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:map_inliner", + 
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -824,7 +825,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -846,7 +846,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -887,7 +886,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -961,17 +959,16 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -997,7 +994,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 73b03440cbb936017257b8a92f16dcc25d41e21c..796a7cf94d02b0ad42366387a9d3f8d589b8840a 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -61,19 +61,6 @@ Disabling these as a starting point. // TODO(b/64227304) Creating a custom pass pipeline will replace this. namespace { -class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager { - public: - FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes) - : llvm::legacy::FunctionPassManager(m), - disable_expensive_passes_(disable_expensive_passes) {} - void add(llvm::Pass* p) override { - llvm::legacy::FunctionPassManager::add(p); - } - - private: - bool disable_expensive_passes_; -}; - class FilteredPassManager : public llvm::legacy::PassManager { public: explicit FilteredPassManager(bool disable_expensive_passes) @@ -96,8 +83,7 @@ class FilteredPassManager : public llvm::legacy::PassManager { std::unique_ptr CompilerFunctor::operator()( llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); - FilteredFunctionPassManager function_passes(&module, - disable_expensive_passes_); + llvm::legacy::FunctionPassManager function_passes(&module); VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 2083f440fdd971db1b675d005664d25e6de53dbe..c58175428fea6a2d38253c35de598b99a4281bf1 100644 --- 
a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloVerifiedTestBase { +class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -87,7 +87,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -150,7 +150,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( 
@@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 4ce5a8a29255a763c83941efb6de9b7c652cedb4..2bf24c15c1f050b200b1d9af2d95286f9a9dbe4c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -76,6 +76,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -249,6 +250,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + pipeline.AddPass(); pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner @@ -268,10 +270,10 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - pass.AddPass( - /*is_layout_sensitive=*/false, - [](const Shape&, const Shape&) { return false; }, - /*enable_dot_strength_reduction=*/false); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_enable_dot_strength_reduction(false); + pass.AddPass(options); pass.AddPass(); // BatchNormExpander 
can create zero-sized ops, so zero-sized HLO @@ -334,10 +336,11 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pass.AddInvariantChecker( /*layout_sensitive=*/true, /*allow_mixed_precision=*/false); - pass.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_dot_strength_reduction(false); + pass.AddPass>(options); pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } @@ -587,9 +590,9 @@ StatusOr> CpuCompiler::RunBackend( // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). - TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module.get(), BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( @@ -779,7 +782,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction())); + ScheduleModule(module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index c9fb34be1cd582c71618c770c892058c233c571a..c085f85fb73e98e4c7ba15af8db8bb19c2499f5f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloVerifiedTestBase { +class CpuCopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -65,7 +65,7 @@ class CpuCopyInsertionTest : public HloVerifiedTestBase { TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). Each constant should be copied. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 3); @@ -103,7 +103,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { // Test a kCall instruction which calls a computation which produces a three // element tuple: one is a constant, one is a parameter, and one is produced // in the computation. The constant and parameter should be copied. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 29abf38e439d919ff93629ed992cb3ff93a929bd..818b2b0d0db2893e11fa46c7867e6c74bbbb6905 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -51,8 +51,7 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), diff --git 
a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 3c3c047bfe8ee0d1ad90ede2432a86264f47870b..3b91b15ba9b5603b50f78f489e9a3fdad354c083 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -49,7 +49,7 @@ class CpuExecutable : public Executable { public: CpuExecutable(std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index e6b6fcdf684eadb3702e490bbe24dbb7b3b52ec7..9cbfb88834bf51f4df54e97efe6cd7bf88b12334 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloVerifiedTestBase { +class CpuHloSupportCheckerTest : public HloTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(CpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. 
- auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 7d99b914d4f5e5d27722bcd098d2ae0c54a36a23..c77d5988ba3d204a6e9da2ff1337d68c44c19e62 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -58,7 +58,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -98,7 +98,7 @@ TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -119,7 +119,7 @@ TEST_F(InstructionFusionTest, 
DotOperationFusion_Reshape) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -138,7 +138,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -157,7 +157,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -321,7 +321,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -350,7 +350,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ 
-370,7 +370,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, broadcast1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -392,7 +392,7 @@ TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -410,7 +410,7 @@ TEST_F(OpcodeFusionTest, Exponential_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -429,7 +429,7 @@ TEST_F(OpcodeFusionTest, Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -447,7 +447,7 @@ TEST_F(OpcodeFusionTest, Reverse_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -466,7 +466,7 @@ TEST_F(OpcodeFusionTest, Slice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -489,7 +489,7 @@ TEST_F(OpcodeFusionTest, 
Exponential_Transpose_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, transpose2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -498,7 +498,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { } TEST_F(OpcodeFusionTest, UnaryMapOfExp) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -517,7 +517,7 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { } TEST_F(OpcodeFusionTest, BinaryMapOfExps) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -542,7 +542,7 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { } TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -573,7 +573,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); @@ -641,7 +641,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { builder.AddInstruction( HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -670,7 +670,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { 
builder.AddInstruction(HloInstruction::CreateBinary( large_shape, HloOpcode::kAdd, small_exp, large_param)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -712,7 +712,7 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, } TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -725,7 +725,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/false); @@ -738,7 +738,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -751,7 +751,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/true); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 97659b88a7974d7caf91ab0d4741f3635e4dae4a..6c61b64758ede160e2d50e4429590a789ec253c3 100644 --- 
a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -73,7 +73,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -114,7 +114,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -158,7 +158,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -192,7 +192,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -232,7 +232,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto 
module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -353,7 +353,7 @@ static void AssertCorrectLayoutForDotOutputFusion( } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -365,7 +365,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -377,7 +377,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -389,7 +389,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -401,7 +401,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { } 
TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, @@ -413,7 +413,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index b8ace5702688096822573c7afae234cbcbe77b28..92debb83e33b1400a59e5eef0f90971392ab7b22 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -22,7 +22,6 @@ limitations under the License. 
namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; -const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaEnableExperimentalLlvmIrGemm = "xla_enable_experimental_llvm_ir_gemm"; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index d6968323f337d83e41b5e031cc49fab5b6a17b21..cf97a8bde0757b67bdea62c30ea0e8e63161c573 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -111,7 +111,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); @@ -140,7 +140,7 @@ StatusOr IrEmitter::EmitComputation( // readcyclecounter if it is unavailable. bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; - profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); + profiling_state_ = ProfilingState(use_rdtscp); if (instruction_order == nullptr) { TF_RETURN_IF_ERROR(computation->Accept(this)); } else { @@ -1379,33 +1379,6 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -// Fills up the free variables in 'index_with_free_var' with values from -// 'filler_index'. The size of free variables must be the same as the -// size of 'filler_index'. 
-// -// This is often used after dimension reduction, where -// 'index_with_free_var' has one or more dimensions reduced, which serves as -// free variables (represented as nullptr). For example, if we have a 4 -// dimensional input and index for the dimension being reduced is -// 2 (third dimension), we will have an index like [i, j, NULL, k] -// after reduced dimension. -// -// Here we fill up that free variable by 'filler_index', which contains -// the value in the reduced dimension. -static llvm_ir::IrArray::Index FillReducedDimensionIndex( - llvm_ir::IrArray::Index index_with_free_var, - llvm_ir::IrArray::Index filler_index) { - llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); - - for (size_t i = 0; i < index_with_free_var.size(); ++i) { - if (index_with_free_var[i] == nullptr) { - index_with_free_var[i] = *it++; - } - } - CHECK(filler_index.end() == it); - return index_with_free_var; -} - Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); return EmitTargetAddressForOp(parameter); @@ -1536,7 +1509,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs, rhs}, {lhs->getType()}, b); @@ -1551,7 +1525,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMinimum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs, rhs}, {lhs->getType()}, b); @@ -2192,14 +2167,6 @@ Status 
IrEmitter::HandlePad(HloInstruction* pad) { return Status::OK(); } -// If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. -static const HloInstruction* StripTranspose(const HloInstruction& hlo) { - if (hlo.IsRank2Transpose()) { - return hlo.operand(0); - } - return &hlo; -} - Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 136b88ff75ea8a5f48b42d3476219f18f5ecb39a..f529c613a3de62996feeca854213155df5943e7b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -101,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -467,9 +467,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // profiling a computation. class ProfilingState { public: - ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {} - ProfilingState(bool use_rdtscp, llvm::Value* prof_counters) - : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} + ProfilingState() : use_rdtscp_(false) {} + explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {} // Record the cycle counter before an HLO executes. void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); @@ -494,9 +493,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // intrinsic? bool use_rdtscp_; - // The argument which corresponds to the profile counter buffer. - llvm::Value* prof_counters_; - // The first read cycle counter in the program. 
llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fad76338a57cd9eb21d9469ca8552efa8ea0129b..f0b65046c14ccec5336abf7c4d05d1d755f783bd 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class ParallelTaskAssignmentTest : public HloVerifiedTestBase { +class ParallelTaskAssignmentTest : public HloTestBase { protected: const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = cpu::CpuExecutable::ShapeSizeBytes; @@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { + : HloTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} @@ -57,8 +57,9 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -84,8 +85,9 @@ TEST_F(ParallelTaskAssignmentTest, } )"; - 
ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -100,8 +102,9 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -116,8 +119,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 669eeb95f3299623a7556bfbb8045fd77f5d0745..722aa3120ef4d8c957873ac58c361f19632dde1f 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -41,61 +42,60 @@ void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { std::sort(row_to_sort, row_to_sort + num_elements); } -// For floating point numbers, we want a total order comparator. -NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. Also we want to have a stable sort, so if the keys are the -// same, we compare the index values. 
-template -bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { - bool lhs_is_negative = std::signbit(lhs); - bool rhs_is_negative = std::signbit(rhs); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; - } - bool lhs_nan = std::isnan(lhs); - bool rhs_nan = std::isnan(rhs); - // Exactly one number is nan? - if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; +// We would like a total order of floating point numbers so that the +// sort has a predictable behavior in the presence of NaNs. Rather +// than using floating point comparison, we use the following trick: +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? 0x7FFFFFFF - x : x; +// then y is ordered as an int32 such that finite values have the +// obvious order, -0 is ordered before 0, and -NaN and NaN appear at +// the beginning and end of the ordering. 
+template +CastType Convert(KeyType value) { + CastType casted_value; + memcpy(&casted_value, &value, sizeof(CastType)); + if (casted_value < 0) { + return static_cast(std::numeric_limits::max()) - + casted_value; } - if (lhs != rhs) { - return lhs < rhs; - } - return lhs_index < rhs_index; + return casted_value; +} + +template +bool LessThan(KeyType lhs, KeyType rhs) { + return Convert(lhs) < + Convert(rhs); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), lhs.second, - Eigen::half_impl::half_to_float(rhs.first), rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), + Eigen::half_impl::half_to_float(rhs.first)); + }); } template diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 
1a3d82de954318368d61e3feeb0345dc592dcd8b..7d8e51f909e3db699b745f94a6c625407bc4a6e3 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloVerifiedTestBase { +class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloVerifiedTestBase { +class ShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { +class RandomShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4b129c95d46d8b5a119e5d23eef387daf7863cce..382dfd0d99df87bbadfe541ddaa32cd6da8e8068 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,7 +48,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 18ee25ba9158c28baaf01492c290638b9673f1ec..f8f5f392da8ab3348e63185aecf7b639daacaa42 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -50,7 +50,7 @@ class CpuEigenDotOperationTest /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(entry_computation)); CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 00a7aa2ad2f6bac4877302296ccb76222557535c..e30f95311fce229f9c559d3bb40142151e8bf3e3 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -46,7 +46,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CompileAndVerifyIr(std::move(module), filecheck_pattern, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 1deb412064b02988a8d4a6d726969c948d354d47..04a81dfd35f459ff1fdb3181dc8fc65c62a37d4f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloVerifiedTestBase { +class CpuFusionTest : public HloTestBase { protected: CpuFusionTest() {} @@ -57,11 +57,11 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -104,11 +104,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -131,7 +131,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal.shape(); @@ -183,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -250,12 +250,12 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); // Create computation and module. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -310,11 +310,11 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate1, negate2, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). 
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index a434c04a980b9b3cd849792b97a0d9e965ba09f2..9b10c49f4f547edfb2164f98c49cceb031148bdc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -91,7 +91,7 @@ TEST_P(CpuUnaryIntrinsicTest, DoIt) { /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); string check_lines{spec.check_lines.data(), spec.check_lines.size()}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 3b87683ffffefd2aa24dd234cc072425bef00a24..fa0e09ff6b5694c0e97963b83c6e541b858a1376 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -63,7 +63,7 @@ CHECK-NOT: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -104,14 +104,14 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [4 x i8] -CHECK: private constant [8 x i8] +CHECK-DAG: private constant [4 x i8] +CHECK-DAG: private constant [8 x i8] CHECK-NOT: private constant [4 x i8] CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index b35fd9dad877c319c3d0110c96a00aeefa78769e..a7702c2aeeaff8a46a2c4f2785ccb873ea2c08e5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -56,7 +56,7 @@ TEST_F(CpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef..64fb50318394918b277fd717994f5366d762ac36 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -18,19 +18,19 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class DefuserTest : public HloVerifiedTestBase { +class DefuserTest : public HloTestBase { protected: // Returns the number of fusion instructions in the module. 
- int FusionCount() { + int FusionCount(const HloModule* m) { int count = 0; - for (HloComputation* computation : module().computations()) { + for (HloComputation* computation : m->computations()) { if (computation->IsFusionComputation()) { count++; } @@ -43,6 +43,7 @@ class DefuserTest : public HloVerifiedTestBase { }; TEST_F(DefuserTest, NoFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -51,13 +52,14 @@ TEST_F(DefuserTest, NoFusionInstruction) { builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - module().AddEntryComputation(builder.Build()); - EXPECT_EQ(0, FusionCount()); + m->AddEntryComputation(builder.Build()); + EXPECT_EQ(0, FusionCount(m.get())); - EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(defuser_.Run(m.get()).ValueOrDie()); } TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -66,21 +68,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } 
TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -91,21 +94,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add(op::Parameter(), op::Parameter()))); } TEST_F(DefuserTest, NonTrivialFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -128,22 +132,23 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction( {add2, constant, div, mul, sub, negate, add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); 
EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, MultipleFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -166,7 +171,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add2, constant, div, mul}, HloInstruction::FusionKind::kLoop); computation->CreateFusionInstruction({sub, negate, add}, @@ -174,15 +179,16 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, NestedFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -193,7 +199,7 @@ TEST_F(DefuserTest, NestedFusionInstructions) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); auto outer_fusion = computation->CreateFusionInstruction( {negate, add}, HloInstruction::FusionKind::kLoop); HloInstruction* fused_negate = outer_fusion->fused_expression_root(); @@ -203,9 +209,9 @@ TEST_F(DefuserTest, 
NestedFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); } diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index b3549acfc291a54b2345b006310613c3a45a4b47..ed37099a5428075928ec98b134632867d58bbfe7 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/defuser.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" namespace xla { @@ -45,6 +46,7 @@ class ControlDepRemover : public HloModulePass { Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. 
+ pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 4159aa281fa2b66d310d7c135f123a5a3bb83270..d6371283221b63b30f968929fe2807eae3f22df0 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -108,6 +108,7 @@ class DfsHloVisitorBase { virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 4cd10ab06cd3b804406607212d3f3c316d6cff95..e57184f639f4f2c618b980a5082381f4b9c28b19 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -203,6 +203,9 @@ class DfsHloVisitorWithDefaultBase Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } + Status HandleGetDimensionSize(HloInstructionPtr get_size) override { + return DefaultAction(get_size); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8bfc8905064bcd7b68fe259fbcc1546ff083dbd --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +Status DynamicParameterBinding::Bind( + const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) { + auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter); + TF_RET_CHECK(result.second); + return Status::OK(); +} + +absl::optional +DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { + auto param_iter = bindings_.find(dynamic_dimension); + if (param_iter == bindings_.end()) { + return absl::nullopt; + } + return param_iter->second; +} + +DynamicParameterBindingProto DynamicParameterBinding::ToProto() const { + DynamicParameterBindingProto result; + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + DynamicParameterBindingProto::Binding binding_proto; + binding_proto.set_dynamic_param_num(dynamic_param.parameter_num); + for (int64 i : dynamic_param.parameter_index) { + binding_proto.add_dynamic_param_index(i); + } + + binding_proto.set_target_param_num(dynamic_dimension.parameter_num); + + for (int64 i : 
dynamic_dimension.parameter_index) { + binding_proto.add_target_param_index(i); + } + + binding_proto.set_target_param_dim_num(dynamic_dimension.dimension); + result.add_entries()->Swap(&binding_proto); + } + return result; +} + +StatusOr DynamicParameterBinding::CreateFromProto( + const DynamicParameterBindingProto& proto) { + DynamicParameterBinding result; + for (const DynamicParameterBindingProto::Binding& binding : proto.entries()) { + int64 dynamic_param_num = binding.dynamic_param_num(); + ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(), + binding.dynamic_param_index().end()); + int64 target_param_num = binding.target_param_num(); + ShapeIndex target_param_index(binding.target_param_index().begin(), + binding.target_param_index().end()); + int64 target_dim_num = binding.target_param_num(); + + TF_RETURN_IF_ERROR( + result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, + DynamicDimension{target_param_num, target_param_index, + target_dim_num})); + } + + return result; +} + +string DynamicParameterBinding::ToString() const { + std::vector pieces; + pieces.push_back("DynamicParameterBinding: "); + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + pieces.push_back(absl::StrFormat( + " -- Input param number %lld at %s has dim %lld as dynamic" + " dimension, which is represented by param number %lld at " + "%s", + dynamic_dimension.parameter_num, + dynamic_dimension.parameter_index.ToString(), + dynamic_dimension.dimension, dynamic_param.parameter_num, + dynamic_param.parameter_index.ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + +Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const { + for (const auto& binding : bindings_) { + TF_RETURN_IF_ERROR(fn(binding.second, binding.first)); + } + return Status::OK(); +} + +Status DynamicParameterBinding::Verify(const HloModule& module) const { + 
const HloComputation* entry = module.entry_computation(); + return ForEachBinding([&](const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) + -> Status { + TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), + dynamic_parameter.parameter_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(), + dynamic_dimension.parameter_index)); + TF_RET_CHECK( + dynamic_dimension.dimension < + ShapeUtil::Rank(ShapeUtil::GetSubshape( + entry->parameter_instruction(dynamic_dimension.parameter_num) + ->shape(), + dynamic_dimension.parameter_index))); + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding) { + out << binding.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..dd474d8eed1b2c30ddb8f624a864198c74eacaba --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; +// We currently use an explicit API that takes an extra parameter to indicate +// the runtime size of a dynamic dimension. DynamicParameterBinding indicates +// the relationship between parameters: We can have a dynamic parameter that +// points to another target parameter to indicate that the target parameter is +// dynamic. +// +// +// TODO(b/119520625): Remove this API once we have more dynamic shape infra +// ready. +class DynamicParameterBinding { + public: + // DynamicParameter represents a special parameter that is used to represent + // the runtime size of a dimension of another parameter. A dynamic parameter + // has to be a scalar value. + struct DynamicParameter { + // The parameter number of dynamic parameter. + int64 parameter_num; + // The index of the parameter. + ShapeIndex parameter_index; + }; + + // DynamicDimension represents a dimension whose size is determined at + // runtime. A DynamicDimension's runtime size is determined by the bound + // DynamicParameter using `DynamicParameterBinding::Bind` method. + struct DynamicDimension { + // The parameter number of dynamic dimension. + int64 parameter_num; + // The subshape index of the parameter. + ShapeIndex parameter_index; + // The dimension number in the subshape. + int64 dimension; + + // "friend" keywords are added so these functions can be found by ADL.
+ template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.parameter_num, m.parameter_index, + m.dimension); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.parameter_num == rhs.parameter_num && + lhs.parameter_index == rhs.parameter_index && + lhs.dimension == rhs.dimension; + } + }; + + DynamicParameterBinding() = default; + + virtual ~DynamicParameterBinding() = default; + + // Adds a binding: the dimension described by `dynamic_dimension` is dynamic + // and its runtime size is held by `dynamic_parameter`. Returns an error if + // the dimension is already bound. + Status Bind(const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension); + + // Returns the parameter and the index representing the runtime size of + // the dimension described by `dynamic_dimension`. + // + // Returns nullopt if the binding is not set. + absl::optional GetBinding( + const DynamicDimension& dynamic_dimension); + + using BindingFn = + std::function; + + // Invokes `fn` on every binding; stops at and returns the first error. + Status ForEachBinding(BindingFn fn) const; + + DynamicParameterBindingProto ToProto() const; + + static StatusOr CreateFromProto( + const DynamicParameterBindingProto& proto); + + string ToString() const; + + // Verifies that the given binding is valid for the given module. + // Specifically, the binding's parameter and parameter size should be valid. + Status Verify(const HloModule& module) const; + + private: + // Keeps track of mappings from DynamicDimension to DynamicParameter. The + // direction is chosen so that we can easily query if a dimension is + // dynamic and which dynamic parameter represents the real size of that + // dimension.
+ absl::flat_hash_map bindings_; +}; + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..83a6d83dffde7995bd8e43917d13c5fd2705ba6f --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class DynamicParameterBindingTest : public HloTestBase {}; + +TEST_F(DynamicParameterBindingTest, SimpleBinding) { + // 'b' is a dynamic shape; 'a' represents the real size of b's first + // dimension. + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[10] parameter(1) + ROOT root = (f32[], f32[10]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + /*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBinding) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's first + // dimension. 
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's both + // dimensions. 
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10, 10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10, 10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10, 10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 1})); + + absl::optional param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional param2 = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/1}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + + TF_EXPECT_OK(binding.Verify(*module)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 515267edd7caf42e04ebe638b99006db8967ea30..00bb430206afdb81f9d101c0a5b2b4cf907b447a 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -1671,26 +1672,66 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( b_->SetInsertPoint(init_block); + // Assign a unique id for each *different* operand, and count how often each + // operand is used. If all operands are different, the usage count will be 1 + // for each operand. + absl::flat_hash_map to_unique_operand_id; + std::vector operand_usage_count; + for (const auto* operand : hlo->operands()) { + if (to_unique_operand_id.contains(operand)) { + ++operand_usage_count[to_unique_operand_id[operand]]; + } else { + int64 unique_operand_id = to_unique_operand_id.size(); + to_unique_operand_id[operand] = unique_operand_id; + operand_usage_count.push_back(1); + } + } + + // To avoid that we emit the same operand more than once, we create one basic + // block for each *different* operand with a PHI node for the different source + // index inputs. 
+ std::vector emit_operand_blocks( + to_unique_operand_id.size(), nullptr); + std::vector source_index_phis(to_unique_operand_id.size(), + nullptr); + for (const auto* operand : hlo->operands()) { + int64 operand_id = to_unique_operand_id[operand]; + if (emit_operand_blocks[operand_id] != nullptr) { + continue; + } + + emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand_id", operand_id), b_); + auto saved_insert_point = b_->GetInsertPoint(); + llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_); + source_index_phis[operand_id] = + PHI(source_index.GetType(), operand_usage_count[operand_id]); + auto operand_index = source_index; + operand_index[concat_dim] = source_index_phis[operand_id]; + + // Create the terminator of the block before calling operand generators, + // because they require non-degenerate basic blocks. + b_->SetInsertPoint(llvm::BranchInst::Create( + exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id])); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(operand_index)); + output->addIncoming(value, b_->GetInsertBlock()); + b_->SetInsertPoint(init_block, saved_insert_point); + } + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), b_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, - false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. 
- b_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, b_->GetInsertBlock()); + int64 operand_id = to_unique_operand_id[operand]; + source_index_phis[operand_id]->addIncoming(source_index[concat_dim], + b_->GetInsertBlock()); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), + emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. @@ -1815,8 +1856,6 @@ StatusOr ElementalIrEmitter::EmitElementalGather( // Clamp the gather index so that the gather region fits in the operand. // gather_dim_component_extended_inbound = // clamp(gather_dim_component_extended, 0, largest_valid_start_index); - - // TODO(b/111078873): This is implementation defined behavior. bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); auto gather_dim_component_extended_inbound = EmitIntegralMin( index.GetConstantWithIndexType(largest_valid_start_index), diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 47c56e2f7fbd9f53be6a2b189c5c36cf4fdcdccb..10b8c01ff1383658fcfb2271c177ba54347f985a 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 3a6780f2a67f230cae626ea00cfbf93b4e60d968..b34bca55a48b113c325dbf28c03f7a0f5b71f658 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "absl/types/variant.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -61,7 +61,7 @@ struct ExecutionOutput { class Executable { public: explicit Executable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : hlo_module_(std::move(hlo_module)), @@ -162,7 +162,7 @@ class Executable { return hlo_profile_printer_data_ != nullptr; } - const HloModule& module() const { return *hlo_module_; } + HloModule& module() const { return *hlo_module_; } const bool has_module() const { return hlo_module_ != nullptr; } @@ -199,7 +199,7 @@ class Executable { // HloModule this was compiled from. BufferAssignment keeps pointers to // HloInstructions owned by the HloModule so we need to keep the HloModule // around. - const std::unique_ptr hlo_module_; + const std::unique_ptr hlo_module_; // HloSnapshot this was compiled from. Null if not dumping executions. 
std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8..8eeb930b48165a2e3c622581e05cb5f7063fa1fa 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloVerifiedTestBase { +class FlattenCallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -108,7 +108,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module); + std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -149,7 +149,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // Test corner case of a computation used as a body and a loop condition. TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation; { HloComputation::Builder builder(TestName() + ".cond"); @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -201,7 +201,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { // C // TEST_F(FlattenCallGraphTest, FlattenCalls) { - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); HloComputation* c_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -224,7 +224,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { } TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* sub_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // The true and false computations must now be different. 
EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1e8435fe542f2b65a11e256453cf911c5e6e833b..bfd1b6cb1492f5cb709e2ecefe73782094e26f5e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -111,7 +111,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -463,7 +462,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -627,7 +626,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -702,6 +701,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_element_type_converter", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", @@ -849,7 +849,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -909,7 +908,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -1036,6 +1034,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc index fa3afa6a5d318c399dc38e8934199b5a1393669e..af9303a5b761b99705945f1c02303156e3f874de 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {}; +class CudnnConvPadForTensorCoresTest : public HloTestBase {}; TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index c46672c598b27670c56b3efa4775be8fea1fc6ac..e81850db69edced29ea31bb2a526b0503bf8a453 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -77,7 +77,11 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { return false; } - if (window_util::HasWindowReversal(conv->window())) { + // CuDNN can perform either cross correlation (no reversal), + // or convolution (all dimensions reversed). + if (dnums.input_spatial_dimensions_size() == 2 + ? !window_util::AllOrNoneReversed(conv->window()) + : window_util::HasWindowReversal(conv->window())) { return false; } return true; @@ -254,7 +258,7 @@ MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // TODO(b/119479517): Theoretically cuDNN supports grouped convolutions also // for the backward input convolution, but at least for now with version 7.1.4 // it is slower. This needs to be re-evaluated for future cuDNN versions. 
// Note that we already have the necessary code down below, the only thing to diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 87a835f2504068548159ef32b276201c936fa385..443883a89f66a747def1049bc5afb53fec3c2409 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -34,11 +34,11 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvRewriterTest : public HloVerifiedTestBase { +class CudnnConvRewriterTest : public HloTestBase { public: CudnnConvRewriterTest() - : HloVerifiedTestBase(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false) { + : HloTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -118,10 +118,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { metadata.set_op_name("foo"); conv->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -152,10 +152,10 @@ TEST_F(CudnnConvRewriterTest, activations, gradients, /*feature_group_count=*/1, conv_window, 
tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -182,10 +182,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -212,10 +212,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -241,10 +241,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); 
EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -292,10 +292,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { /*feature_group_count=*/1, conv_window, conv_dnums) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -338,10 +338,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -371,10 +371,10 @@ TEST_F(CudnnConvRewriterTest, default_conv_window_, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -425,10 +425,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - 
EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -475,10 +475,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -529,10 +529,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -584,10 +584,10 @@ TEST_F(CudnnConvRewriterTest, conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -600,7 +600,8 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { constant_arr.FillIota(0); string constant_str = LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); - ParseAndVerifyModule(absl::StrFormat(R"( + + const string module_str = 
absl::StrFormat(R"( HloModule test ENTRY entry_computation { @@ -610,10 +611,12 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, dim_labels=bf01_01oi->bf01, feature_group_count=1 })", - constant_str)); - EXPECT_TRUE(RunPass(&module())); + constant_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); EXPECT_THAT( - module().entry_computation()->root_instruction(), + m->entry_computation()->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, op::Reverse(op::Constant())), 0)); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index 492d290bf4a27a91fa14dea95ac62d90bc1fa28a..3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -138,6 +138,7 @@ Status RunCudnnConvImpl(CudnnConvParams params, const int num_dimensions = window.dimensions_size(); CHECK_LE(num_dimensions, 3); + CHECK_GE(num_dimensions, 1); // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. // This matches the behavior of TF (see definition of conv1d in @@ -148,10 +149,15 @@ Status RunCudnnConvImpl(CudnnConvParams params, output_shape.element_type()) << ShapeUtil::HumanString(output_shape); + // If one dimension is reversed, we need to have all dimensions reversed (so + // we're doing convolution not cross correlation). 
+ const bool dims_reversed = window.dimensions()[0].window_reversal(); + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dims_reversed, dim.window_reversal()); CHECK_EQ(dim.padding_low(), dim.padding_high()); CHECK_EQ(dim.base_dilation(), 1) << "cudnn does not support base dilation; it " @@ -198,6 +204,7 @@ Status RunCudnnConvImpl(CudnnConvParams params, ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); convolution_descriptor.set_group_count(feature_group_count); + convolution_descriptor.set_convolution_not_crosscorr(dims_reversed); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -363,14 +370,12 @@ StatusOr GetCudnnConvParams( params.output_shape = &conv_result_shape; params.fusion.emplace(); auto& fusion = *params.fusion; - if (backend_config.activation_mode() < - static_cast(se::dnn::ActivationMode::kNumActivationModes)) { - fusion.mode = static_cast( - backend_config.activation_mode()); - } else { + if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { return InternalError("Bad activation mode: %s", backend_config.ShortDebugString()); } + fusion.mode = static_cast( + backend_config.activation_mode()); fusion.side_input_scale = backend_config.side_input_scale(); params.input_buf = operand_buffers[0]; params.filter_buf = operand_buffers[1]; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 30c1f9088968305ad0207164ecb07ba13cc89ee6..470457935acacb8940af241dadb393d770786939 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -229,7 +229,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { if 
(!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - (user->fusion_kind() == HloInstruction::FusionKind::kInput && + (IsReduceInputFusion(*user) && LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 57426327822d95a42f407ed7488f35acfd3623d2..ae2e718db29803a085401969a7d9b09abf690a6c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -51,7 +51,7 @@ GpuExecutable::GpuExecutable( const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 0e276282e40fba0ae4881a51dad0c7c9e8d1c081..2b3c77f5b82aa94f44d8de56caf0f4d31c05e0cb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,7 @@ class GpuExecutable : public Executable { GpuExecutable(const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 2d31fd5570c468b0c42fa308535fd335f3588a79..392b149abdfb5bf2ce76e8f9f7c4f2cba898ac8c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc 
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -55,7 +55,7 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, }); } -bool IsInputFusibleReduction(const HloInstruction& instr) { +bool IsReduceInputFusion(const HloInstruction& instr) { if (instr.IsMultiOutputFusion()) { for (const HloInstruction* operand : instr.fused_expression_root()->operands()) { @@ -67,17 +67,18 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { return true; } } - return false; - } else if (instr.opcode() == HloOpcode::kFusion) { - if (IsReductionToVector(*instr.fused_expression_root())) { - CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) - << " Fusion rooted at reduction-to-vector op must be of kind kInput: " - << instr.ToString(); - return true; - } - return false; + } else if (instr.opcode() == HloOpcode::kFusion && + IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; } - return IsReductionToVector(instr); + return false; +} + +bool IsInputFusibleReduction(const HloInstruction& instr) { + return IsReduceInputFusion(instr) || IsReductionToVector(instr); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index f7c24a0d5bbfcc61389ea19ae7f769671e4e974d..c0be354730d22fb76754a60a1c9c58781d0d452a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -33,14 +33,17 @@ namespace gpu { bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, const HloInstruction& reduce); -// Whether `instr` is fusible as root of a reduce input fusions, i.e. 
`instr` -// is either an unfused reduction-to-vector op, an input fusion rooted at a -// reduction-to-vector op, or a multi-output input fusion with at least one -// reduction-to-vector op root. // Note that reduction ops are lowered in different ways. Reduce input fusions // are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at // reduction-to-vector ops. Other reduction ops are lowered by // GpuElementalIrEmitter and fused like elementwise ops. + +// Whether `instr` is an input fusion rooted at a reduction-to-vector op or a +// multi-output input fusion with at least one reduction-to-vector op root. +bool IsReduceInputFusion(const HloInstruction& instr); + +// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` +// is either an unfused reduction-to-vector op or a reduce input fusion. bool IsInputFusibleReduction(const HloInstruction& instr); } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index d91b7bc61fda5a07c163a07ec0e1644d2ad9db49..12222500ea732a4ca8ea6b3a37033f7e8d4ee927 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -178,7 +178,7 @@ TEST_F(GpuFusibleTest, EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ReductionToVector) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -191,10 +191,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, 
IsInputFusibleReduction_ElementalReduction) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ElementalReduction) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -207,10 +208,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -225,10 +227,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -243,10 +246,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -263,11 +267,12 @@ 
TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + IsReduceInputFusion_MultiOutputInputReduceFusionWithExtraOutputs) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -284,10 +289,11 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -304,11 +310,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + IsReduceInputFusion_MultiOutputLoopFusionReduceAndElementwiseOp) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -325,6 +332,7 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } diff --git 
a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 02a0d028c118aba23996f9b97d05443bb4a00c88..1126943624a3771433ecac591545d335c1890115 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -37,12 +37,12 @@ class GpuHloOrdering : public PredecessorHloOrdering { public: GpuHloOrdering(const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order); + const std::vector& thunk_launch_order); ~GpuHloOrdering() override = default; // Only the entry computation can possibly be sequentially ordered, and only // if we've assigned all instructions to a single stream. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return &computation == module_->entry_computation() ? entry_sequence_.get() : nullptr; @@ -51,17 +51,17 @@ class GpuHloOrdering : public PredecessorHloOrdering { string ToString() const override { return ToStringHelper("GpuHloOrdering"); } private: - std::unique_ptr> entry_sequence_; + std::unique_ptr entry_sequence_; }; GpuHloOrdering::GpuHloOrdering( const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order) + const std::vector& thunk_launch_order) : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. 
if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = absl::make_unique>( - thunk_launch_order); + entry_sequence_ = + absl::make_unique(thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering( for (auto* computation : module->computations()) { if (computation != module->entry_computation() && !computation->IsFusionComputation()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, + HloReachabilityMap::Build(computation)); } } } @@ -149,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering( // However, if the total order is A,B,D,C,E, then C and E can run // concurrently. void BFSLaunchOrder(const HloComputation* computation, - std::vector* launch_order) { + std::vector* launch_order) { // This topological sort uses two data structures: // 1. `incoming_edge_count` which keeps track of the number of incoming // edges to each HLO; @@ -157,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation, // // The sorting algorithm repeatedly pops the top from the queue and deletes // that HLO from the graph, making more HLOs incoming-edge free. 
- std::deque queue; + std::deque queue; std::unordered_map incoming_edge_count; - for (const auto& hlo : computation->instructions()) { + for (auto* hlo : computation->instructions()) { if (hlo->operand_count() == 0) { queue.push_back(hlo); } else { @@ -171,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation, } while (!queue.empty()) { - const HloInstruction* x = queue.front(); + HloInstruction* x = queue.front(); queue.pop_front(); launch_order->push_back(x); - for (const HloInstruction* y : x->users()) { + for (HloInstruction* y : x->users()) { --incoming_edge_count[y]; if (incoming_edge_count[y] == 0) { queue.push_back(y); @@ -194,14 +195,14 @@ StatusOr> GpuHloSchedule::Build( std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. - const HloComputation* entry_computation = module.entry_computation(); + HloComputation* entry_computation = module.entry_computation(); if (stream_assignment.StreamCount() == 1) { // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( HloInstructionSequence sequence, ScheduleComputation( - *entry_computation, [pointer_size](const BufferValue& buffer) { + entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); schedule->thunk_launch_order_ = sequence.instructions(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 07a7fc67aa555845c3de57e574ab582403ec0490..7f224ffe4f03f8f05b0f1907628d99d9df387770 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -46,7 +46,7 @@ class GpuHloSchedule { // Returns the total order of thunk launches, represented in terms of HLO // instructions. 
- const std::vector& ThunkLaunchOrder() const { + const std::vector& ThunkLaunchOrder() const { return thunk_launch_order_; } @@ -60,7 +60,7 @@ class GpuHloSchedule { private: GpuHloSchedule(); - std::vector thunk_launch_order_; + std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index b857fa775a76ec999b505a2a64332cc0c54cf00b..91db7151f22fd75b20244878bee86d65acd1d304 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,16 +24,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloVerifiedTestBase { +class GpuHloScheduleTest : public HloTestBase { protected: - using HloVec = std::vector; + using HloVec = std::vector; // Pre-canned shapes. 
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); @@ -44,7 +44,7 @@ class GpuHloScheduleTest : public HloVerifiedTestBase { .ConsumeValueOrDie(); } - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -79,7 +79,7 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr streams = AssignStreams(*module); @@ -139,7 +139,7 @@ TEST_F(GpuHloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add3)); std::unique_ptr streams = AssignStreams(*module); @@ -209,7 +209,7 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr streams = AssignStreams(*module); @@ -288,7 +288,7 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr streams = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 7d01eeb02567d710e9de089c7f29ffcc5f959f9a..b511155f85fb24adc1828cbef7f3fb60778ef7ab 100644 --- 
a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloVerifiedTestBase { +class GpuHloSupportCheckerTest : public HloTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(GpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. 
- auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4822b820f4e229336e2b26cfbd0097c8c31a50c8..2ffc8bfb49b205dced0d540ba72426e72d95e596 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -61,7 +61,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(add)); @@ -148,7 +148,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { {operand, scale, offset, mean, variance, epsilon, feature_index}, kCudnnBatchNormForwardInferenceCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -217,7 +217,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, kCudnnBatchNormForwardTrainingCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -298,7 +298,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { feature_index}, kCudnnBatchNormBackwardCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc 
b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 7f2b59810f0334b24a50fc83b85ab838002afd23..43f43b50e4a6478f343088194871cc9d380bd2d2 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 57e66f5a12cf54824c3139ce2fb32e7cf762b040..2b060b03ceae9bf6947f896dae2987a50972013b 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -41,7 +41,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), exp1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -61,7 +61,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), negate1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -80,7 +80,7 @@ TEST_F(InstructionFusionTest, HloInstruction* reshape2 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1)); - auto module = 
CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -99,7 +99,7 @@ TEST_F(InstructionFusionTest, HloInstruction* transpose2 = builder.AddInstruction( HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -134,7 +134,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -723,7 +723,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { sum = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(b.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -805,5 
+805,26 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) { op::Reduce(op::Broadcast(op::Parameter()), op::Constant())); } +TEST_F(InstructionFusionTest, FuseReverse) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY Reverse { + p0 = f32[50,96,1024]{2,1,0} parameter(0) + add = f32[50,96,1024]{2,1,0} add(p0, p0) + ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reverse(op::Add(op::Parameter(), op::Parameter()))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 21e44e1e7d3fb7818e114b70025bfb85eacf786a..ebd73f3a9124fbbfeabf3d5041d44a3da0ddd2fb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -65,11 +65,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -88,6 +88,8 @@ limitations under the License. namespace xla { namespace gpu { +using llvm_ir::KernelMappingScheme; + namespace { using absl::InlinedVector; @@ -1188,7 +1190,7 @@ Status IrEmitterUnnested::EmitColumnReduction( .EmitLoop(IrName(reduce), index_ty); } -static std::pair ComputeTilingSchemeForReduction( +static std::pair ComputeKernelMappingSchemeForReduction( int64 depth, int64 width, int64 kWarpSize) { constexpr int64 kTargetNumElementsPerThread = 64; int64 x_tile_size = kTargetNumElementsPerThread; @@ -1322,7 +1324,7 @@ Status IrEmitterUnnested::EmitRowReduction( int64 x_tile_size; int64 z_tile_size; std::tie(x_tile_size, z_tile_size) = - ComputeTilingSchemeForReduction(depth, width, kWarpSize); + ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. 
@@ -2171,7 +2173,18 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; Shape keys_shape = sort->operand(0)->shape(); + int64 dimension_to_sort = sort->dimensions(0); + // In case there is a 'values' parameter that is a iota, we take note and use + // it later to ensure a stable sort. Otherwise, we don't guarantee a stable + // sort. + int64 iota_values_parameter_index = -1; for (int64 i = 0; i < sort->operand_count(); ++i) { + if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && + ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && + Cast(sort->operand(i))->iota_dimension() == + dimension_to_sort) { + iota_values_parameter_index = i; + } ShapeIndex shape_index = sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the @@ -2196,10 +2209,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } } - int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - auto index_type = b_.getInt64Ty(); + CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); + CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); // Naive C++ code for the outer loops: // @@ -2213,42 +2226,120 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // } // } // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter - + // This follows the alternative representation of the algorithm described on + // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter + // + // Each mask specifies how to derive from one position in the array the + // position with which it should be compared (we calculate the xor of the + // position 
with the mask). + // As an optimization, we can move the 'mask' loop to inside the + // sorting/comparison loop if the comparisons happen within a small block of + // the array. To make this work, we collect all consecutive masks that are + // smaller than our chosen power of 2 tile size, and pass them to SortInPlace. + // Each thread then processes one tile of data. + + const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages); + + // If we cannot combine several xor masks together, we don't use tiling, so we + // calculate the standard launch dimensions for the shape. However we only + // need to iterate through ~half of the dimension to sort (rounded up to the + // next highest power of 2), because each iteration compares one pair of + // elements. + Shape standard_iteration_shape = keys_shape; + uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1); + standard_iteration_shape.set_dimensions(dimension_to_sort, + standard_num_iterations_in_sort_dim); + LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions( + standard_iteration_shape, ir_emitter_context_->device_description()); + + // Calculate the launch dimensions for the case where we use tiling. We split + // the dimension that should be sorted into tiles of size 'kTileSize'. This + // means we first need to round 'dimension_to_sort_bound' up to be a multiple + // of the tile size. + int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize); + Shape iteration_shape = keys_shape; + + // We iterate through the element pairs that should be compared. + uint64 num_iterations_in_sort_dim = rounded_bound / 2; + iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim); + uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape); + + // For correctness reasons we need exactly 'kTileSize' / 2 many threads per + // block. 
Each thread is responsible for copying exactly two adjacent elements + // into shared memory, and then does a comparison of two possibly different + // elements taken from shared memory. + const uint64 kThreadsPerBlock = kTileSize / 2; + + // Check whether we should use any tiling. We might not be able to use it if + // we have not enough threads, or not enough shared memory. Also it does not + // give a speedup if the tile size is < 128. + int64 total_shared_memory_needed = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + total_shared_memory_needed += + kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type()); + } + bool no_tiling = + kTileSize < 128 || + kThreadsPerBlock > + ir_emitter_context_->device_description().threads_per_block_limit() || + total_shared_memory_needed > + ir_emitter_context_->device_description().shared_memory_per_block(); + + uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); + LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); + + auto emit_kernel = [&](absl::Span xor_masks) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = xor_masks.size() > 1 + ? tiled_launch_dimensions + : standard_launch_dimensions; + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + IrArray keys_array; + std::vector values_arrays; + values_arrays.reserve(sort->operand_count() - 1); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); + if (i == 0) { + keys_array = GetIrArray(*sort, *sort, shape_index); + } else { + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + } + } + return llvm_ir::EmitSortInPlace( + dimension_to_sort, keys_array, values_arrays, + iota_values_parameter_index, IrName(sort), xor_masks, &b_, + launch_dimensions, + xor_masks.size() > 1 ? num_iterations_in_sort_dim + : standard_num_iterations_in_sort_dim, + kTileSize); + }; + std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { for (int64 mask = stage; mask >= 0; --mask) { - thunks.push_back( - BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - keys_shape, ir_emitter_context_->device_description()); - UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), - ir_emitter_context_->llvm_module()); - - llvm::Value* xor_mask; + int64 xor_mask; if (mask == stage) { - xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + xor_mask = (1LL << (stage + 1)) - 1; } else { - xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + xor_mask = 1LL << mask; } - - IrArray keys_array; - std::vector values_arrays; - values_arrays.reserve(sort->operand_count() - 1); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); - if (i == 0) { - keys_array = GetIrArray(*sort, *sort, shape_index); - } else { - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + if (xor_mask >= kTileSize || no_tiling) { + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + xor_masks.clear(); } + TF_RETURN_IF_ERROR(emit_kernel({xor_mask})); + } else { + xor_masks.push_back(xor_mask); } - TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_mask, - &b_, &launch_dimensions)); } } + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + } AddThunkToThunkSequence( absl::make_unique(std::move(thunks), sort)); @@ -3068,31 +3159,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } -int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_in_reduced_shape_arrays->reserve(num_outputs); - output_reduced_shapes->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(), - reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); - } - } else { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - hlo.shape().element_type(), reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); - } - return num_outputs; -} int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const 
HloInstruction& hlo, const std::vector& param_arrays, @@ -3152,308 +3218,525 @@ llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, "block.id.x"); } -// Emits code to process up to (tile_size/num_rows) elements in a tile, given -// `emit_elem_function` is the function to emit code to process one element, `y` -// and `x` are the coordinates for the first element to process, and `index` is -// the index for the origin of the tile. Emits bounds check to ensure that each -// processed element is within the boundary defined by `tile_width` and -// `tile_height`. +void EmitFullTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + for (int64 i = 0; i < tile_size_y; i += num_threads_y) { + IrArray::Index source_idx_y = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), + KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); + } + } +} + +void EmitPartialTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* 
tile_width, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + + ksl->IfReturnVoid( + "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { + // tile_height_bound = + // ceil(tile_height / num_threads_y) * num_threads_y + llvm::Value* ceiling_of_ratio = builder->CreateUDiv( + builder->CreateAdd(tile_height, llvm::ConstantInt::get( + index_ty, num_threads_y - 1)), + llvm::ConstantInt::get(index_ty, num_threads_y)); + llvm::Value* tile_height_bound = builder->CreateMul( + ceiling_of_ratio, + llvm::ConstantInt::get(index_ty, num_threads_y)); + ksl->ForReturnVoid( + loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/tile_height_bound, + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), + [&] { + emit_elem_function( + source_idx.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc); + }); + }); + }); + } +} + +// Emits code to process up to +// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile, +// given `emit_elem_function` is the function to emit code to process one +// element, `y` and `x` are the intra-tile coordinates for the first element +// to process, and `index` is the index for the origin of the tile. Information +// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. 
Emits +// bounds check to ensure that each processed element is within the boundary +// defined by `tile_width` and `tile_height`. void EmitTiledElementalCodeWithBoundsCheck( - int64 tile_size, int64 num_rows, const IrArray::Index& index, - const string& loop_name, KernelSupportLibrary* ksl, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - // Emits a constant value with index type. - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - // Adds `addend` to the given `dim` of `index`. 
- auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - - auto emit_full_tile = [&] { - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(index_typed_constant(i), y); - emit_elem_function(source_idx, y_loc); - } - }; - auto emit_last_row = [&] { - ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { - // tile_height_upper_bound = - // ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, - index_typed_constant(num_rows - 1)), - index_typed_constant(num_rows)), - index_typed_constant(num_rows)); - ksl->ForReturnVoid( - loop_name, /*start=*/index_typed_constant(0), - /*end=*/tile_height_upper_bound, - /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { - auto y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), - y_loc); - }); - }); - }); - }; ksl->IfReturnVoid( "full_tile", builder->CreateAnd( - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), - emit_full_tile, emit_last_row); + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), + tile_width), + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), + tile_height)), + [&] { + EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, + emit_elem_function); + }, + [&] { + EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, tile_height, tile_width, index_ty, + emit_elem_function); + }); } } // namespace -// Emits a kernel for the given hlo instruction using a tiled 0-2-1 
transpose -// algorithm to improve the memory access patterns for the input parameters -// which have a shape that is a 0-2-1 transpose of the output tensors. -// -// For the purpose of tiling, the output tensors have a logical shape of three -// components 0-2-1 while the relevant input parameters have a logical shape of -// three components 0-1-2 in the order major to minor. The x- and y- dimensions -// of the tensors are tiled in square tiles of edge length `kTileSize`. Each -// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each -// thread copies kTileSize/kNumRows elements from the input to a shared memory -// tile, then the otherwise "regular hlo kernel" reads from the shared memory -// instead of the original input. -// -// This is similar to the following CUDA algorithm in TensorFlow: -// https://goo.gl/MStRV6. +// Emits code to process a tensor element in a tile for the given kCopy HLO that +// performs a 0-2-1 transpose. // -// `kTileSize` should usually be same as warp size. We currently choose 32 for -// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// index: The index for the first output element in the normalized tensor. The +// normalized tensor is the resulting tensor after collapsing contiguous +// dimensions that play the same role in the transpose. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. +// kernel_info: Other information to support the kernel code generation. +void IrEmitterUnnested::EmitTileElementForCopy( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + // TODO(jlebar): Add AA metadata to this load. 
+ llvm::Instruction* load_from_shmem_buffer = + Load(GEP(tiled_param_info->GetBufferForParameter(0), + {b_.getInt64(0), x_loc, y_loc}), + "output_element"); + llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); + Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( + hlo->shape().element_type(), + kernel_info->GetKernelMappingScheme()->GetDimensionsInElements()); + // When the output_reduced_shape is a 0-2-1 transpose of the input shape, + // the 0-2-1 transpose is achieved through EmitWriteArrayElement. + output_array.CastToShape(output_reduced_shape, &b_) + .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_); +} + +// Emits code to process a tensor element in a tile for the given kLoop fusion +// HLO containing parameters that are 0-2-1 transpose of its outputs. // -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles. -LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, absl::Span reduced_output_dims, - absl::Span tiled_param_ids) { - // Parameters for the tiling algorithm. - constexpr int64 kTileSize = 32; - constexpr int64 kNumRows = 4; - constexpr int64 kThreadsPerTile = kTileSize * kNumRows; - - // Construct IrArrays for the inputs and outputs. +// index: The index for the first output element in the normalized tensor, that +// is the resulting tensor after collapsing contiguous dimensions that play +// the same role in the transpose. +// kernel_info: Other information to support the kernel code generation. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. 
+void IrEmitterUnnested::EmitTileElementForFusion( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); - int64 num_outputs = output_arrays.size(); - std::vector param_arrays = ConstructIrArrayForInputs(*hlo); - int64 num_params = param_arrays.size(); + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elem_emitter); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = + kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, output_arrays[0].GetShape()); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + DCHECK(output_value->getType()->isStructTy()); + DCHECK_EQ(output_value->getType()->getStructNumElements(), + output_arrays.size()); + for (int64 i = 0; i < output_arrays.size(); ++i) { + output_arrays[i].EmitWriteArrayElement( + untiled_index, ExtractValue(output_value, i), &b_); + } + } else { + output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_); + } +} + +// Emits a block of tiles, given a function object to emit one tile. 
+void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, + llvm::Type* index_ty) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); + absl::Span dims_in_block = + mapping_scheme->GetDimensionsInBlocks(); + absl::Span block_sizes = mapping_scheme->GetBlockSizes(); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Emit all the tiles for a given dimension in a tile block. + auto emit_tiles_for_block_dim = + [&](const string& loop_name, const IrArray::Index& starting_tile, + int dim_id, + const std::function + emit_next_block_dim) { + if (block_sizes[dim_id] == 1) { + emit_next_block_dim(starting_tile); + } else { + llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id]; + llvm::Value* block_size_for_dim = + index_typed_constant(block_sizes[dim_id]); + llvm::Value* block_id_for_dim = + b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); + llvm::Value* last_block_for_dim = + index_typed_constant(dims_in_block[dim_id] - 1); + llvm::Value* last_block_size_for_dim = index_typed_constant( + dims_in_tile[dim_id] - + (dims_in_block[dim_id] - 1) * block_sizes[dim_id]); + llvm::Value* num_tiles_in_block = + Select(ICmpEQ(last_block_for_dim, block_id_for_dim), + last_block_size_for_dim, block_size_for_dim); + + ksl.ForReturnVoid( + loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); + } + }; + + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + const bool block_contains_multi_tiles = + mapping_scheme->GetNumberOfTilesInOneBlock() > 1; + + // Emit the tile with a 
given tile_index, by calculating the tight bounds for + // each dimension of the tile and then calling emit_one_tile. + auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { + std::vector output_tile_bounds(3); + for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; + ++i) { + int64 tile_size_for_dim = mapping_scheme->GetTileSizeForDimension(i); + // Only last row or column may not have full size. + llvm::Value* is_last_row = + ICmpEQ(tile_index[i], index_typed_constant(dims_in_tile[i] - 1)); + int64 partial_row_size = + reduced_dims[i] - (dims_in_tile[i] - 1) * tile_size_for_dim; + output_tile_bounds[i] = + Select(is_last_row, index_typed_constant(partial_row_size), + index_typed_constant(tile_size_for_dim), "tile_bound"); + } + + IrArray::Index tile_origin = + mapping_scheme->GetElementIndexForTileOrigin(tile_index); + emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles); + }; + const IrArray::Index starting_block = + mapping_scheme->EmitBlockIndex(index_ty); + const IrArray::Index starting_tile_for_dim_z = + mapping_scheme->GetTileIndexForBlockOrigin(starting_block); + + // Emit the three dimensional block of tiles. + emit_tiles_for_block_dim( + "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ, + [&](const IrArray::Index& starting_tile_for_dim_y) { + emit_tiles_for_block_dim( + "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY, + [&](const IrArray::Index& starting_tile_for_dim_x) { + emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x, + KernelMappingScheme::DimX, + emit_one_tile_for_tile_index); + }); + }); +} + +// Emits a kernel for the hlo instruction using the given kernel mapping scheme. +// +// unnested_hlo: The unnested hlo instruction for which the kernel is generated. +// Currently, these hlo instructions are supported: kLoop fusion, kCopy. 
+// tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of +// other tensors with the same dimensions and need to be tiled and tranposed. +// mapping_scheme: The tiling scheme to use. +// kernel_generator: Contains function objects for code generation, such as +// element generator, block prologue and epilogue generators. +// kernel_info: Represent other information to support the code generation +// of the tiled kernel for the hlo. +LaunchDimensions IrEmitterUnnested::EmitKernel( + HloInstruction* unnested_hlo, absl::Span tiled_param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + + std::vector param_arrays = ConstructIrArrayForInputs(*unnested_hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector param_shmem_buffers(num_params, nullptr); for (int64 id : tiled_param_ids) { - const HloInstruction* param = hlo->operand(id); - // Add 1 to the minor dimension to reduce shared memory bank conflicts. 
- llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - param->shape().element_type(), module_), - kTileSize + 1), - kTileSize); - const int kNVPTXSharedMemoryAddrSpace = 3; - auto* tile_base_ptr = new llvm::GlobalVariable( - *b_.GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), - llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, - llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); - param_shmem_buffers[id] = tile_base_ptr; + const HloInstruction* param = unnested_hlo->operand(id); + param_shmem_buffers[id] = + mapping_scheme->GetSharedMemoryBufferForElementType( + llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(), + module_), + IrName(unnested_hlo, StrCat("tile", id))); VLOG(3) << "Added shmem buffer for parameter " << id << ": " - << llvm_ir::DumpToString(*tile_base_ptr); - } - - // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result - // for the purpose of tiling. Calculate the logical output dimensions in the - // tile from the reduced output dimensions. 
- std::vector output_dims_in_tiles = std::vector( - reduced_output_dims.begin(), reduced_output_dims.end()); - CHECK_EQ(output_dims_in_tiles.size(), 3); - for (int i = 1; i < 3; ++i) { - output_dims_in_tiles[i] = - CeilOfRatio(output_dims_in_tiles[i], kTileSize); + << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - const int64 num_tiles = - absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); - LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); - llvm::Type* index_ty = - GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); + LaunchDimensions launch_dimensions = LaunchDimensions( + mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); + llvm::Type* index_ty = GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_); auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - // Cast each output IrArray to its corresponding reduced shape and keep the - // reduced shape live during IR emission. - std::vector output_in_reduced_shape_arrays; - std::vector output_reduced_shapes; - CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, - &output_in_reduced_shape_arrays), - num_outputs); - // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. 
std::vector param_in_reduced_shape_arrays; std::vector param_reduced_shapes; - CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( - *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, - ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays), - num_params); + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape( + *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims, + ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays); + DCHECK_EQ(num_shapes, num_params); // Calculate the starting element coordinate within a tile for the current // thread, (y, x) from thread_id. llvm::Value* x; llvm::Value* y; - std::tie(y, x) = CalculateYXCoordinateWithinTile( - &b_, index_typed_constant(kTileSize), kThreadsPerTile); - - // Calculate the index for the current output tile from block_id. - const IrArray::Index output_tile_index( - GetBlockIdx(&b_, index_ty, num_tiles), - ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, - output_dims_in_tiles), - &b_); - - // Output tile origin is the index for the first element of the current output - // tile. - const IrArray::Index output_tile_origin = [&] { - IrArray::Index index = output_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); - } - return index; - }(); + std::tie(y, x) = mapping_scheme->EmitThreadYXCoordinate(index_ty); - // Calculate the input tile origin from the output tile origin. - const IrArray::Index input_tile_origin( - Permute({0, 2, 1}, output_tile_origin.multidim())); - - // Calculate the current output tile bounds in each of the logical dimensions. - std::vector output_tile_bounds(3); - for (int i = 1; i < 3; ++i) { - // Only last row or column may not have full size. 
- output_tile_bounds[i] = - Select(ICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); - } + kernel_info->SetLaneId( + mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x + : nullptr); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { - EmitTiledElementalCodeWithBoundsCheck( - kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, - tile_height, emit_elem_function); + llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, + &ksl, &b_, y, x, tile_height, + tile_width, emit_elem_function); }; - // Adds `addend` to the given `dim` of `index`. - auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = Add(index[dim], addend); - return index; - }; - const IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); - - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. 
- Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); - } - }); + auto emit_one_tile = [&](const IrArray::Index& output_tile_origin, + absl::Span output_tile_bounds, + bool block_contains_multi_tiles) { + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + const IrArray::Index input_index = + input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + // If shared memory transpose is needed, wait for all threads to reach this + // point, lest we copy a value from tile to output before the other thread + // copies it from input to tile. This is `__syncthreads` in CUDA. 
+ if (!tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } - llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + const IrArray::Index output_index = + output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values for - // the tiled parameters are read from the shmem buffers. - if (hlo->opcode() == HloOpcode::kCopy) { + // Write to output[index] by emitting code like normal, except that values + // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - // TODO(jlebar): Add AA metadata to this load. 
- llvm::Instruction* load_from_shmem_buffer = - Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, load_from_shmem_buffer, &b_); - }); - } else { - CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), - &elem_emitter); - tiled_param_info.set_y(y_loc); - fused_emitter.SetTiledParameterInfo(&tiled_param_info); - TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); - IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( - index, output_reduced_shapes[0], output_arrays[0].GetShape(), - &b_); - const llvm_ir::ElementGenerator& output_generator = - fused_emitter.GetRootGenerator(); - llvm::Value* output_value = - output_generator(untiled_index).ValueOrDie(); - if (hlo->IsMultiOutputFusion()) { - CHECK(output_value->getType()->isStructTy()); - CHECK_EQ(output_value->getType()->getStructNumElements(), - output_in_reduced_shape_arrays.size()); - for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { - output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, ExtractValue(output_value, i), &b_); - } - } else { - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, output_value, &b_); - } + output_index, "output", output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + kernel_generator.GetTileElementGenerator()(unnested_hlo, index, + kernel_info, y_loc, x_loc); }); + // If a tile block contains multiple tiles and shared memory buffers are + // used, we need to wait for all threads to finish using the shared memory + // buffer 
for the current tile before we move on to process the next tile + // and overwrite the shared memory buffers. + if (block_contains_multi_tiles && !tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } + }; + + const BlockPrologueGenerator& block_prologue_generator = + kernel_generator.GetBlockPrologueGenerator(); + if (block_prologue_generator) { + block_prologue_generator(unnested_hlo, kernel_info); } - // For multioutput fusion, emit a tuple with all the individual outputs. - if (hlo->IsMultiOutputFusion()) { - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); + EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + + const BlockEpilogueGenerator& block_epilogue_generator = + kernel_generator.GetBlockEpilogueGenerator(); + if (block_epilogue_generator) { + block_epilogue_generator(unnested_hlo, kernel_info); + } + + // For multioutput fusion, emit a tuple with pointers to all the individual + // outputs. + if (unnested_hlo->IsMultiOutputFusion()) { + std::vector output_arrays = + ConstructIrArrayForOutputs(*unnested_hlo); + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), output_arrays, + &b_, module_); } return launch_dimensions; } +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// with a shape that is a 0-2-1 transpose of the output tensor shape. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape +// of three components 0-1-2 in the order major to minor. The x- and y- +// dimensions of the tensors are tiled in square tiles with an edge length +// `kTileSize`. 
Each thread block of `kTileSize` x `kNumRows` threads +// transposes one tile: each thread copies kTileSize/kNumRows elements from +// the input to a shared memory tile, then the otherwise "regular HLO kernel" +// reads from the shared memory instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more +// efficient to launch fewer blocks so each transposes many tiles. +LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { + constexpr int kNumRows = 4; + KernelMappingScheme mapping_scheme( + reduced_output_dims, /*tile_size_y=*/kWarpSize, + /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1}, + /*num_threads_y=*/kNumRows, + /*num_threads_x=*/kWarpSize, &b_); + TileElementGenerator element_generator; + if (hlo->opcode() == HloOpcode::kCopy) { + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + }; + } else { + DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); + }; + } + KernelCodegenInfo kernel_info(&mapping_scheme); + KernelCodeGenerator kernel_generator(std::move(element_generator)); + return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info); +} + +namespace { +// Returns true to indicate it is safe to use the tile based shared memory +// transpose 
implementation to implement the kernel for the instruction. +// +// An instruction is not safe for such an implementation if it can change the +// element order of a tensor without changing the dimension of the tensor, and +// the instruction has a corresponding elemental_ir_emitter. +bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { + auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { + HloOpcode opcode = instr->opcode(); + CHECK_NE(opcode, HloOpcode::kFusion); + return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); + }; + + if (hlo->opcode() == HloOpcode::kFusion) { + return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), + is_safe_for_tile_based_transpose); + } + + return is_safe_for_tile_based_transpose(hlo); +} +} // namespace + bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); @@ -3465,8 +3748,8 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { ? ShapeUtil::GetSubshape(hlo->shape(), {0}) : hlo->shape(); - // If the output_shape is reduced to 021 shape, find all the parameters of the - // hlo that are in the corresponding 012 shape. + // If the output_shape is reduced to 021 shape, find all the parameters of + // the HLO that are in the corresponding 012 shape. std::vector params_012; optional> reduced_dims_021; for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); @@ -3498,10 +3781,14 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } + if (!IsInstructionSafeForTileBasedTranspose(hlo)) { + return false; + } + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the - // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb - // shared memory per SM. 
(This is increased to 96kb in Volta, but we don't - // use this, in part because it eats into our L1 cache space.) + // elements are of size 4 bytes), and CUDA has an architectural limit of + // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we + // don't use this, in part because it eats into our L1 cache space.) // // For correctness we need to ensure that we don't make more than 48kb worth // of shmem tiles per block. And for performance, we'd probably like to use @@ -3509,9 +3796,9 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // gpu core. // // We say without benchmarks that we want at least 3 threads/block, - // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose - // which params get the shmem transpose treatment arbitrarily; it's not clear - // if there's a Right Choice. + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We + // choose which params get the shmem transpose treatment arbitrarily; it's + // not clear if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use // shared memory in fusions. If in the future other fusible ops use shared @@ -3565,10 +3852,10 @@ Status IrEmitterUnnested::EmitConstantGlobals() { } // These globals will be looked up by name by GpuExecutable so we need to - // give them an external linkage. Not all of their uses are visible in the - // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely - // preserves their names (like available_externally), we also need to ensure - // that they stick around even if they're "unused". + // give them an external linkage. Not all of their uses are visible in + // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that + // merely preserves their names (like available_externally), we also need + // to ensure that they stick around even if they're "unused". 
// // We may have to be more more clever here in the future if we notice that // we're keeping around too many globals because of their linkage. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 334c0b3c20b0888fa9b167a8979221f0184a82e7..97a1e10455336cd4842275b6cf1482614bfbfa60 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" namespace xla { @@ -47,6 +48,94 @@ namespace gpu { // class IrEmitterUnnested : public IrEmitter { public: + // Parameter block_contains_multi_tiles indicates whether a tile block + // consists of multiple tiles or not. If the tile block contains only one + // tile, there is no need to use atomic operation to accumulate a local result + // to a global result to implement reduction. + using TileGenerator = + std::function output_tile_bounds, + bool block_contains_multi_tiles)>; + // KernelCodegenInfo records the common information to support the code + // generation for a kernel to process tensor elements by blocks. A block of + // tensor elements may contain one or multiple tiles. The code generators that + // generate code for tile elements or block prologue/epilogue refer to this + // class in their prototypes. If the implementations of such code generators + // require other information that are specific to the HLO instructions, the + // implementations need to define and use derived classes of this class. 
+ class KernelCodegenInfo { + public: + explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) + : mapping_scheme_(mapping_scheme), + tiled_param_info_(nullptr), + lane_id_(nullptr) {} + + void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { + CHECK_EQ(tiled_param_info_, nullptr); + tiled_param_info_ = tiled_param_info; + } + + llvm::Value* GetLaneId() const { return lane_id_; } + llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const { + return mapping_scheme_; + } + llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { + return tiled_param_info_; + } + + private: + llvm_ir::KernelMappingScheme* mapping_scheme_; + llvm_ir::TiledParameterInfo* tiled_param_info_; + llvm::Value* lane_id_; + }; + + // A function object to prepare for the code generation for a tile block. + using BlockPrologueGenerator = + std::function; + // A function object to finalize the code generation for a tile block. + using BlockEpilogueGenerator = + std::function; + // A function object to generate code to process one element in a tile. + // + // hlo: the instruction for which the code is generated for. + // index: the index for the first output element of the current thread. + // y_loc: The y coordinate within a tile. + // x_loc: The x coordinate within a tile. + // kernel_info: Other information to support the kernel code generation. + using TileElementGenerator = std::function; + + // KernelCodeGenerator records the code generator objects that generate code + // for tile elements or tile block prologue/epilogue. 
+ class KernelCodeGenerator { + public: + explicit KernelCodeGenerator( + TileElementGenerator tile_element_generator, + BlockPrologueGenerator block_prologue_generator = {}, + BlockEpilogueGenerator block_epilogue_generator = {}) + : tile_element_generator_(std::move(tile_element_generator)), + block_prologue_generator_(std::move(block_prologue_generator)), + block_epilogue_generator_(std::move(block_epilogue_generator)) {} + + const TileElementGenerator& GetTileElementGenerator() const { + return tile_element_generator_; + } + const BlockPrologueGenerator& GetBlockPrologueGenerator() const { + return block_prologue_generator_; + } + const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const { + return block_epilogue_generator_; + } + + private: + TileElementGenerator tile_element_generator_; + BlockPrologueGenerator block_prologue_generator_; + BlockEpilogueGenerator block_epilogue_generator_; + }; + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context); @@ -205,22 +294,32 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span reduced_output_dims, absl::Span tiled_param_ids); + // Emits a kernel for an unnested HLO instruction. + LaunchDimensions EmitKernel(HloInstruction* unnested_hlo, + absl::Span param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info); + void EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, llvm::Type* index_ty); + // Emits code to process a tensor element in a tile for the given kCopy HLO + // that performs a 0-2-1 transpose. 
+ void EmitTileElementForCopy(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given kLoop fusion + // HLO containing parameters that are 0-2-1 transpose of its outputs. + void EmitTileElementForFusion(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. std::vector ConstructIrArrayForInputs( const HloInstruction& hlo); - // For each output of the `hlo` instruction, constructs the reduced shape for - // the output with the given `reduced_output_dims` and cast the original - // output IrArray element in `output_arrays` to the reduced shape. Returns - // the number of outputs. - int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, - const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in // `param_buffers` to find out whether the input has a reduced shape. 
If the // input has a reduced shape, constructs the reduced shape for the input and diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2..364f69a69d47644b383af9cf6865c93360b82bab 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -453,12 +453,12 @@ void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // * 3-6 gives similar results as 2; // * >6 start hurting the performance of at least dot product kernels. // - // TODO(jingyue): The current threshold only considers the numbr of IR + // TODO(jingyue): The current threshold only considers the number of IR // instructions which do not accurately reflect the true cost. We need a // better cost model. FeedLLVMWithFlags({"-bonus-inst-threshold=2"}); - // TODO(b/22073864): Increase limit when scan memory dependency. - // This helps to reduce more redundant load instructions. + // Increase limit when scanning memory dependencies. This helps to reduce + // more redundant load instructions. // // The specific value is currently large enough for s3d in shoc benchmark, // which contains a lot of load instructions and many arithmetic instructions diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 9427d3d54addc7d794ddc0a8f4c45b39b248bc5f..d9b06828e2b5d334873c88cb49c2e0d5675bb5fe 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -140,6 +140,18 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, return false; } + // The emitter only supports in-place DUS for fusions with a single DUS at the + // root. Don't sibling fuse DUS for now. 
+ // TODO(b/119178699): Multi-output fusing DUS can improve performance if we + // share the input and output buffers and add support to the emitter. + if (instr1->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice || + (instr2->opcode() == HloOpcode::kFusion && + instr2->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice)) { + return false; + } + // Do this check last, as it may be expensive. return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 1d4856e0cae163bbd9ab741917b85792097d8512..d16c87ba5c63aa582753fe949e9e39ee2d8b81e5 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -580,7 +580,7 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { // ... // where each of the (pi * pj)'s is represented as a fusion node so that // multi-output fusion will pay attention to it. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); @@ -621,5 +621,39 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { } } +TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { + auto module = ParseHloString(R"(HloModule dus_mof + fusion.1 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + fusion.2 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + ENTRY entry { + p.00 = f16[50,96,1024]{2,1,0} parameter(0) + p.01 = f16[50,96,1024]{2,1,0} parameter(1) + p.1 = s32[1]{0} parameter(2) + p.2 = f16[1,96,1024]{2,1,0} parameter(3) + + f1 = f16[50,96,1024] fusion(p.00, p.1, p.2), kind=kLoop, calls=fusion.1 + f2 = f16[50,96,1024] fusion(p.01, p.1, p.2), kind=kLoop, calls=fusion.2 + ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2) + })") + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index de04ed85c30717f5be7c5485ff3b68270c8ec188..637b861f70235f17e8e739907a3f262b7004ee7c 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -67,6 +67,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -142,6 +143,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); + pipeline.AddPass(); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass(); @@ -177,9 +179,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass(); - pass.AddPass( - /*is_layout_sensitive=*/false, + AlgebraicSimplifierOptions options( [](const Shape&, const Shape&) { return false; }); + options.set_enable_permutation_sort_replacement(true); + pass.AddPass(options); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -248,11 +251,13 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, + AlgebraicSimplifierOptions options( /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_permutation_sort_replacement(true); + pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. // @@ -810,7 +815,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, // binaries are not available. We don't want to spam logs with // identical warnings in this case. 
- // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // TODO(jlebar): we should implement a LOG_FIRST_N and LOG_EVERY_N // for more general usage. static std::atomic warning_done(false); log_warning = !warning_done.exchange(true); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5b6cf2c04d05378a363232e33a6df6432cd6848e..4775baf44aecfe6adaf2bf0d2791595436635b16 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -122,7 +122,7 @@ std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = - computation.ComputeReachability(); + HloReachabilityMap::Build(&computation); std::vector seen_gemms; // The execution of different RNG Hlo instructions in the same module updates // a common global variable. To avoid a race condition, we simply assign all diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index c4f43cc9a614283acb376b5f98e4976615b590ad..31a5d7a8c04e9863830e2026fc73cd7ded8c322e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,16 +21,16 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloVerifiedTestBase { +class StreamAssignmentTest : public HloTestBase { protected: - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -55,7 +55,7 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr assignment = AssignStreams(*module); @@ -76,7 +76,7 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr assignment = AssignStreams(*module); @@ -120,7 +120,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr assignment = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index ed46f08d5970d479db33a7b9ad416a1480535764..d798b31643782eb25bba08227e29903ec0e7a597 100644 --- 
a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -37,7 +37,7 @@ cc_library( hdrs = ["gpu_codegen_test.h"], tags = tf_cuda_tests_tags(), deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 79e77d4c4d649020cf52ac25c220c3f90e8469b9..9e3ff8750b88d08bcbc1aae3faead5aecfa19848 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -23,9 +23,10 @@ limitations under the License. namespace xla { namespace gpu { -std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { +std::unique_ptr GpuCodegenTest::CreateNewUnverifiedModuleWithFTZ( + bool ftz) { HloModuleConfig config; - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_gpu_ftz(ftz); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index e4a3573babb7ed746504c1466f85b582aa4d044f..d917320e36363c4fa7e4c0055e8f3345cbc610a2 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -26,9 +26,9 @@ namespace gpu { // Tests that verify IR or PTX emitted by the GPU backend is as expected. class GpuCodegenTest : public LlvmIrGenTestBase { protected: - // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz - // option. - std::unique_ptr CreateNewModuleWithFTZ(bool ftz); + // Like HloTestBase::CreateNewVerifiedModule(), with a flag for configuring + // the ftz option. + std::unique_ptr CreateNewUnverifiedModuleWithFTZ(bool ftz); // Compiles the given HLO module to PTX and verifies the PTX matches the given // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 780539c164277f14c2bd964024f7c3ca179f4ada..a1ed8499040359fe7265a7317b0577a990a2234c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -46,7 +46,7 @@ TEST_F(GpuCopyTest, UseMemcpy) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // There should not be any kernel prefixed "copy". 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index 177b94934c7f519172508b5cc6e088f908401193..d0ccd8619bde9ddd560989380b403efed5c5f42c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -39,7 +39,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/1, param_shape, "y")); builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } @@ -54,7 +54,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/0, param_shape, "x")); builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index a06576df7b874745236a8d9075355a01ec42e777..6814be779e0b02c38e3bc7008f036b845d88cb6f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); // Check the optimized IR as the unoptimized IR contains dead udiv and urem. 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 15d1e269cc22b88f5269175084f20600f165011c..a302b582ede3723acd118d2e4a4bb3efdf7a4d0b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -193,6 +193,33 @@ TEST_F(GpuKernelTilingTest, /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { + const char *const kHloString = R"( + HloModule FusionTransposeWithReverseNotTiled + fused_computation.1 { + arg0 = f32[128,64]{1,0} parameter(0) + copy0 = f32[128,64]{0,1} copy(arg0) + ROOT reverse0 = f32[128,64]{0,1} reverse(copy0), dimensions={0} + } + + ENTRY reverse_break_assumption { + param0 = f32[128,64]{1,0} parameter(0) + ROOT fusion0 = f32[128,64]{0,1} fusion(param0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6a9ecd9dae7c9ddde0b56d8615e4a39fb3df0af9..3019215c015a4e0aa094a62424d650ced0de2a0e 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -48,7 +48,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -73,7 +73,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { builder.AddInstruction(HloInstruction::CreateTuple({add, square})); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -95,7 +95,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { // reduce in the foreseeable future. But if that turns out to be wrong, I give // you, future reader, permission to delete this test. 
TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation* reduce_computation; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 15198865bda98f9718342d5a444a20305f923b48..ca0a78034d7dc83d17ad72202914d95f37ac122b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -47,7 +47,7 @@ TEST_F(GpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 0f2d5568cafc9db0f5f067437fdd5e2e775ad2c8..4636f1d9d20b8c213ffadec427b3820a89c68a7f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -85,7 +85,7 @@ TEST_F(GpuUnrollingTest, UnrollFourTimes) { TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { // The default unrolling factor is 4. 
HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 141f3219387940a08ef22cbcc0be0971a14c2cd6..6b2d76764a077dc6cfa3f9ddc6e525ab330323be 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( ThunkSchedule::ThunkSchedule( std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order) + const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { std::unordered_map hlo_to_thunk; @@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule( InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } - for (const HloInstruction* hlo : hlo_total_order) { + for (HloInstruction* hlo : hlo_total_order) { if (hlo_to_thunk.count(hlo)) { thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); } diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index d3352994f845a535233612a17e19107511ce0622..43b628a1baf0e79a3197f3cfad3547991642eaed 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -46,7 +46,7 @@ class ThunkSchedule { public: ThunkSchedule(std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order); + const std::vector& hlo_total_order); // Returns the total order of executing all the thunks. 
const std::vector& TotalOrder() const { return thunk_total_order_; } diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc index 5fa9e91050a85b67eb22a48d47e4dd157a53c699..3d00ac4dc7b57664a317157c093d7ffaa01b4fd6 100644 --- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -32,7 +32,7 @@ namespace gpu { namespace { using match::Concatenate; -class VariadicOpSplitterTest : public HloVerifiedTestBase {}; +class VariadicOpSplitterTest : public HloTestBase {}; TEST_F(VariadicOpSplitterTest, DontSplit) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 926b59a1b854bd3d7d2699124e10b70147e52e2a..2dce7749bbd8da2673ae607eee3d731d9917e8fe 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -29,7 +29,7 @@ namespace { class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(CreateNewModule()), + : module_(CreateNewVerifiedModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc 
b/tensorflow/compiler/xla/service/heap_simulator_test.cc index e30e7667f3015bc7bfe67c65147a5016332780f7..dc40b9446ad1bffcb757543e52fc9ab20de6d52e 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,16 +30,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; +class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module); + HloSchedule schedule(module.get()); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -258,7 +258,7 @@ class HeapSimulatorTracker { // Constructor for testing a single entry computation. 
HeapSimulatorTracker( const string& name, std::unique_ptr computation, - const std::vector& instruction_sequence) { + const std::vector& instruction_sequence) { HloModuleConfig config; module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); @@ -286,7 +286,7 @@ class HeapSimulatorTracker { // Similar to the single entry computation constructor above, but runs the // simulation over the entire module. void RunWholeModule( - const std::vector& full_module_sequence) { + const std::vector& full_module_sequence) { points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -294,7 +294,7 @@ class HeapSimulatorTracker { HloSchedule schedule(module_.get()); absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { - const HloInstruction* instruction = full_module_sequence[i]; + HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; @@ -351,7 +351,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloVerifiedTestBase { +class HeapSimulatorTest : public HloTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index dbab62f847e8ca5e0b46dfd4162a0f4222640252..913d4c34b43087d322634dbc436f2f7c5666c77a 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -251,6 +251,41 @@ message HloInputOutputAliasProto { repeated AliasEntryProto entries = 1; } +message DynamicParameterBindingProto { + // A list of bindings which indicates that the `target_param_dim_num` in + // the subshape `target_param_index` of parameter `target_param_num` + // is a dynamic dimension and its real dynamic size is represented + // by
`dynamic_param_index` in parameter `dynamic_param_num`. + // + // As an example, imagine we have a program: + // + // ENTRY main { + // a = f32[] parameter(0) + // b = f32[10] parameter(1) + // ROOT root = (f32[], f32[10]) tuple(%a, %b) + // } + // + // Let's say 'b' (param index 1) is a dynamic shape whose input has + // an upper bound of 10 and real size is determined at runtime. 'a' + // represents the real size of b's first dimension. + // + // In this case, the fields are set in the following way: + // dynamic_param_num = 0 + // dynamic_param_index = {} + // target_param_num = 1 + // target_param_index = {} + // target_param_dim_num = 0 + message Binding { + int64 dynamic_param_num = 1; + repeated int64 dynamic_param_index = 2; + int64 target_param_num = 3; + repeated int64 target_param_index = 4; + int64 target_param_dim_num = 5; + } + + repeated Binding entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -272,6 +307,8 @@ message HloModuleProto { // Describes alias information between inputs and outputs. HloInputOutputAliasProto input_output_alias = 8; + + DynamicParameterBindingProto dynamic_parameter_binding = 9; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 5c8d97b2d15e15d15cb8014a7d25b37437ce8aec..7e6150e94153cd15463725e862ce1b8593f2c991 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,17 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloVerifiedTestBase { +class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : HloVerifiedTestBase() { - module_ = CreateNewModule(); + HloAliasAnalysisTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_, + analysis_ = HloAliasAnalysis::Run(module_.get(), /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -93,7 +93,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. 
bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -110,7 +110,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { return false; } - HloModule* module_; + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -638,7 +638,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_).status()); + TF_ASSERT_OK(flattener.Run(module_.get()).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1012,7 +1012,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -1054,13 +1054,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } // For a sequential order, if there is interference iff the negate is after // the while. 
- HloSchedule schedule(module_); + HloSchedule schedule(module_.get()); schedule.set_sequence(body, {body_param, body_root}); schedule.set_sequence(condition, {cond_param, cond_root}); { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b0f7cd91ad1db0a59c09cfbfc1885813dc57e01e..65bd251dd8642314e62dffc118e30e62de1844e4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -321,7 +322,7 @@ void HloComputation::ComputeInstructionPostOrder( // Add the operands to the stack in reverse order so the first operand is // processed first. This will produce a more natural ordering and a nicer - // result for thigns like HLO stringification. + // result for things like HLO stringification. 
const auto& operands = current->operands(); for (int64 i = operands.size() - 1; i >= 0; --i) { dfs_stack.emplace_back(operands[i]); @@ -739,72 +740,6 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -std::unique_ptr HloComputation::ComputeReachability() - const { - const auto& all = MakeInstructionPostOrder(); - auto result = absl::make_unique(all); - auto channel_dependency_map = ComputeChannelDependencies(); - - std::vector inputs; - for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); - - switch (hlo->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - break; - } - case HloOpcode::kCrossReplicaSum: { - auto all_reduce_id = hlo->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - } - break; - } - default: - break; - } - - result->FastSetReachabilityToUnion(inputs, hlo); - } - return result; -} - -void HloComputation::UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map) { - std::queue worklist; - worklist.push(instruction); - - std::vector inputs; - - while (!worklist.empty()) { - const HloInstruction* item = worklist.front(); - worklist.pop(); - - inputs.assign(item->operands().begin(), item->operands().end()); - inputs.insert(inputs.end(), item->control_predecessors().begin(), - item->control_predecessors().end()); - - if (reachability_map->SetReachabilityToUnion(inputs, item)) { - // Add immediate successors to worklist. 
- for (const HloInstruction* user : item->users()) { - worklist.push(user); - } - for (const HloInstruction* succ : item->control_successors()) { - worklist.push(succ); - } - } - } -} - std::vector HloComputation::CollectUnreachableRoots() const { std::vector unreachable_roots; for (auto* instruction : instructions()) { @@ -860,7 +795,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + const std::vector& order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -890,9 +825,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, const std::vector&) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, const std::vector&) const; Status HloComputation::Accept( const std::function& visitor_func) { @@ -911,14 +846,46 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - /*extras=*/{}, context, suffix); + context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr 
HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + replacements.emplace(std::move(r3)); + return CloneWithReplacements(std::move(replacements), context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context, - const string& suffix) { + HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { context_ptr = absl::make_unique(parent(), suffix); @@ -939,18 +906,50 @@ std::unique_ptr HloComputation::CloneWithReplacements( }; VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; + + // We want to do a postorder walk over [replace(i) for i in instructions_]. + // We can't reuse MakeInstructionPostOrder() for this, because that will + // generate a postorder of plain instructions_, and our replacements may + // change the postorder! + // + // The postorder we want here is simpler than what MakeInstructionPostOrder() + // does -- we only care about operand dependencies -- so let's just do it + // ourselves. 
std::vector postorder; - for (HloInstruction* instr : extras) { - postorder.push_back(instr); - } - for (HloInstruction* instr : MakeInstructionPostOrder()) { - if (HloInstruction* replacement = replace(instr)) { - postorder.push_back(replacement); + absl::flat_hash_map visited; + for (const auto& instr : instructions_) { + std::vector dfs_stack; + HloInstruction* new_instr = replace(instr.get()); + if (!new_instr) { + continue; + } + dfs_stack.push_back(new_instr); + + while (!dfs_stack.empty()) { + auto* cur = dfs_stack.back(); + auto it = visited.find(cur); + if (it != visited.end()) { + dfs_stack.pop_back(); + if (it->second == kVisited) { + continue; + } + CHECK_EQ(it->second, kVisiting); + postorder.push_back(cur); + it->second = kVisited; + continue; + } + + visited.insert({cur, kVisiting}); + for (HloInstruction* operand : cur->operands()) { + HloInstruction* new_operand = replace(operand); + if (new_operand) { + dfs_stack.emplace_back(new_operand); + } + } } } std::vector> instructions; - std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { @@ -960,9 +959,8 @@ std::unique_ptr HloComputation::CloneWithReplacements( << operand->ToString() << ", used by " << instr->ToString(); new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, context); - instructions.push_back(std::move(new_instr)); + instructions.push_back( + instr->CloneWithNewOperands(instr->shape(), new_operands, context)); } Builder builder(name() + "." 
+ suffix); for (auto& instr : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index dec96d11a93cf56d3c40a6bb7882ffb7336aeeb0..be1ce336968504b6406c9ef4b879821821c5b187 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" @@ -215,19 +214,6 @@ class HloComputation { // this order, definitions of values always appear before their uses. std::vector MakeInstructionPostOrder() const; - // Computes and returns the reachability between HLO instructions in the - // computation. The returned HloReachabilityMap is constructed such that - // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a - // directed path (from producer to consumer) from 'a' to 'b'. Both data - // dependencies (operands) and control dependencies are considered for - // reachability. Trivially an instruction is reachable from itself. - std::unique_ptr ComputeReachability() const; - - // Updates the given reachability map after the immediate predecessor set - // (operands and control predecessors) of 'instruction' has changed. - void UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this @@ -315,7 +301,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. 
template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + const std::vector& order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); @@ -333,14 +319,38 @@ class HloComputation { // the map's value to replace that instruction in the cloned computation. // // If replacements maps a key to nullptr, we remove that instruction from the - // new computation. - // If additional instructions are used by instructions in replacement map, - // they must be passed in post-order in the extras span. + // new computation. If an element of `replacements` references an instruction + // that's not already in the computation, it's cloned and added to the new + // computation. + // + // All relevant instructions are cloned, *including* unique_ptr in the + // `replacements` map. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); + + // Convenience overloads for CloneWithReplacements. You want to do + // + // CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}}) // ERROR + // + // but that doesn't work because std::initializer_list is not movable. These + // overloads let you do + // + // CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)}); // OK + // + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. 
// Parameter instructions cannot be removed without violating invariants of @@ -355,6 +365,14 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + absl::flat_hash_map>; + ChannelDependencyMap ComputeChannelDependencies() const; + // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. bool HasSideEffect() const; @@ -410,14 +428,6 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating - // instructions. 
- using ChannelDependencyMap = - absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; - enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 2aaaef1d36d58bcce18db4aa37ff05ea352e484b..8b50cfa9aed90091cfbedc1df902440ec9bf2a80 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -65,7 +65,7 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); @@ -73,7 +73,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map_computation = @@ -85,7 +85,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map1_computation = @@ -119,7 +119,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -134,7 +134,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); @@ -151,7 +151,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Trace instructions should be at the end of the sort. 
EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -170,7 +170,7 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), UnorderedElementsAre(constant1, constant2, constant3, constant4)); @@ -192,7 +192,7 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); @@ -217,7 +217,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. 
class TestVisitor : public DfsHloVisitorWithDefault { @@ -257,7 +257,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -274,7 +274,7 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -376,7 +376,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { // copied. auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); @@ -393,7 +393,7 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -412,7 +412,7 @@ TEST_F(HloComputationTest, CycleDetection) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); - auto module = CreateNewModule(); + auto module = 
CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); @@ -440,7 +440,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); @@ -466,7 +466,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); @@ -484,107 +484,6 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } -TEST_F(HloComputationTest, Reachability) { - // Test reachability of a non-trivial computation: - // - // const1 const2 - // | | - // | +-------+ - // | | | - // add .. negate - // | . | - // | .... exp - // | | - // +---+ +-+---+ - // | | | - // multiply copy - // - // There is a control dependency from 'add' to 'exp'. 
- auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); - - auto module = CreateNewModule(); - auto computation = - module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); - - TF_CHECK_OK(add->AddControlDependencyTo(exp)); - auto reachability = computation->ComputeReachability(); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_TRUE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_TRUE(reachability->IsReachable(constant1, copy)); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_TRUE(reachability->IsReachable(constant2, negate)); - EXPECT_TRUE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_TRUE(reachability->IsReachable(constant2, copy)); - - EXPECT_FALSE(reachability->IsReachable(exp, constant1)); - 
EXPECT_FALSE(reachability->IsReachable(exp, constant2)); - EXPECT_FALSE(reachability->IsReachable(exp, add)); - EXPECT_FALSE(reachability->IsReachable(exp, negate)); - EXPECT_TRUE(reachability->IsReachable(exp, exp)); - EXPECT_TRUE(reachability->IsReachable(exp, mul)); - EXPECT_TRUE(reachability->IsReachable(exp, copy)); - - EXPECT_FALSE(reachability->IsReachable(mul, constant1)); - EXPECT_FALSE(reachability->IsReachable(mul, constant2)); - EXPECT_FALSE(reachability->IsReachable(mul, add)); - EXPECT_FALSE(reachability->IsReachable(mul, negate)); - EXPECT_FALSE(reachability->IsReachable(mul, exp)); - EXPECT_TRUE(reachability->IsReachable(mul, mul)); - EXPECT_FALSE(reachability->IsReachable(mul, copy)); - - EXPECT_TRUE(reachability->IsConnected(constant1, copy)); - EXPECT_TRUE(reachability->IsConnected(copy, constant1)); - EXPECT_FALSE(reachability->IsConnected(negate, add)); - EXPECT_FALSE(reachability->IsConnected(add, negate)); - - // Remove the control dependency then update and verify the reachability map - ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); - computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_FALSE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_FALSE(reachability->IsReachable(constant1, copy)); - - // Change a use within the graph then update and verify the reachability map - ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); - computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - 
EXPECT_FALSE(reachability->IsReachable(constant2, negate)); - EXPECT_FALSE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_FALSE(reachability->IsReachable(constant2, copy)); -} - TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -606,7 +505,7 @@ TEST_F(HloComputationTest, Stringification) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -641,7 +540,7 @@ TEST_F(HloComputationTest, StringificationIndent) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = @@ -677,7 +576,7 @@ TEST_F(HloComputationTest, StringificationCanonical) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -700,27 +599,5 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -TEST_F(HloComputationTest, ChannelReachability) { - const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); - HloComputation::Builder builder("ChannelReachability"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto token0 = 
builder.AddInstruction(HloInstruction::CreateToken()); - auto send = - builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); - auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); - auto recv = - builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); - auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(recv_done)); - auto reachability = computation->ComputeReachability(); - EXPECT_TRUE(reachability->IsReachable(param, recv_done)); - EXPECT_FALSE(reachability->IsReachable(send, recv)); - EXPECT_FALSE(reachability->IsReachable(send_done, recv)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 4f898ce61c3f36e83e4b13130a404dbb4a2c36c6..5e37883d3d8d5067bab873ac6b5f732e7360c5fa 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -52,8 +52,10 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, and AfterAll operation. - // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // Skip Constant, Parameter, Tuple, AfterAll operation. + // Tuple constants are not directly supported by any backends, hence + // folding Tuple is not useful and would in fact be expanded back into + // kTuple by Algebraic Simplifier. // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this // special case is not necessary. 
@@ -63,6 +65,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { instruction->opcode() == HloOpcode::kAfterAll) { continue; } + // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index e45f905f7152c37a9ab2b41d407310671310c2a3..d12f920722e20a3390a99f74c8a10c7c9e3fdf6c 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloVerifiedTestBase; +using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -46,13 +46,13 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -67,13 +67,13 @@ 
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -88,13 +88,13 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -130,11 +130,11 @@ TEST_F(HloConstantFoldingTest, Concatenate) { Shape shape = ShapeUtil::MakeShape(F32, dimensions); builder.AddInstruction(HloInstruction::CreateConcatenate( shape, operands, test_config.concat_dimension)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -157,11 +157,11 @@ TEST_F(HloConstantFoldingTest, Slice) { Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); 
builder.AddInstruction(HloInstruction::CreateSlice( shape, literal_instruction, slice_start, slice_limits, slice_strides)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -182,11 +182,11 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { const int64 permutation[] = {1, 2, 0, 4, 3}; builder.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -219,27 +219,28 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - ParseAndVerifyModule(kConstantFoldReduce); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_TRUE(result); - EXPECT_EQ(6, module() - .entry_computation() + EXPECT_EQ(6, m->entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - ParseAndVerifyModule(kConstantFoldReduce); - HloInstruction* add = module().computations().begin()->root_instruction(); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); + 
HloInstruction* add = m->computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_FALSE(result); - EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce()); } const char* const kConstantFoldLargePad = R"( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 108aeea097d7170d236b988c414b517a1a284640..fdfb38b858c32ba5b092ec2db84d4bac487c3e78 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -269,7 +269,7 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(map->to_apply())); + ProcessNestedSubcomputation(map->to_apply())); // Compute the cost of all elements for this Map operation. const int64 element_count = ShapeUtil::ElementsIn(map->shape()); @@ -285,7 +285,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. // This counts the number of times the reduction function is applied, so it @@ -311,7 +311,7 @@ Status HloCostAnalysis::HandleReduceWindow( auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. 
TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each // output element there are window_size - 1 reductions to perform. @@ -336,9 +336,9 @@ Status HloCostAnalysis::HandleSelectAndScatter( // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties select_properties, - ProcessSubcomputation(instruction->select())); + ProcessNestedSubcomputation(instruction->select())); TF_ASSIGN_OR_RETURN(const Properties scatter_properties, - ProcessSubcomputation(instruction->scatter())); + ProcessNestedSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter // source element there are window_size - 1 select computations to perform and @@ -574,7 +574,7 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation())); + ProcessNestedSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. 
// Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -595,7 +595,7 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { Status HloCostAnalysis::HandleCall(const HloInstruction* call) { TF_ASSIGN_OR_RETURN(current_properties_, - ProcessSubcomputation(call->to_apply())); + ProcessUnnestedSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -624,13 +624,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. - // - // TODO(b/26346211): Improve the cost analysis for while nodes. TF_ASSIGN_OR_RETURN(const Properties body_properties, - ProcessSubcomputation(xla_while->while_body())); + ProcessUnnestedSubcomputation(xla_while->while_body())); - TF_ASSIGN_OR_RETURN(const Properties condition_properties, - ProcessSubcomputation(xla_while->while_condition())); + TF_ASSIGN_OR_RETURN( + const Properties condition_properties, + ProcessUnnestedSubcomputation(xla_while->while_condition())); current_properties_.clear(); for (const auto& property : body_properties) { @@ -647,10 +646,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { // Compute the cost of the true and false computations and take the maximum // from those for each property. 
- TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, - ProcessSubcomputation(conditional->true_computation())); - TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, - ProcessSubcomputation(conditional->false_computation())); + TF_ASSIGN_OR_RETURN( + const Properties true_computation_properties, + ProcessUnnestedSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN( + const Properties false_computation_properties, + ProcessUnnestedSubcomputation(conditional->false_computation())); current_properties_ = true_computation_properties; for (const auto& property : false_computation_properties) { if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { @@ -680,7 +681,7 @@ Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { const int64 element_count = ShapeUtil::ElementsIn(scatter->operand(2)->shape()); TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(scatter->to_apply())); + ProcessNestedSubcomputation(scatter->to_apply())); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * element_count; @@ -689,6 +690,11 @@ Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { return Status::OK(); } +Status HloCostAnalysis::HandleGetDimensionSize( + const HloInstruction* /*get_size*/) { + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -725,10 +731,19 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } -StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation) { +StatusOr +HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.properties(); +} + 
+StatusOr +HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) { HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + hlo_properties_.insert(visitor.hlo_properties_.begin(), + visitor.hlo_properties_.end()); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 46b4bbeef222e6de581360fc01b293e812f1dedd..8ced9d776e150ac587e9ac3ed0beffbc38dc5503 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -107,6 +107,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; Status HandleScatter(const HloInstruction* scatter) override; + Status HandleGetDimensionSize(const HloInstruction* get_size) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -153,7 +154,24 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // Returns the properties computed from visiting the computation rooted at the // given hlo. - StatusOr ProcessSubcomputation(HloComputation* computation); + // + // The difference between ProcessNestedSubcomputation and + // ProcessUnnestedSubcomputation is that we expect to get profile results for + // an unnested subcomputation's individual instructions, while we expect that + // a nested subcomputation is completely subsumed by its parent. + // + // For example, subcomputations inside kFusion and kMap are considered nested, + // while subcomputations inside kWhile and kConditional are considered + // unnested. 
+ // + // Another way of thinking of this is, kFusion is implemented on the GPU + // backend using just one GPU kernel, while kWhile's body is implemented as a + // sequence of kernels, one for each HLO therein. Backends don't necessarily + // need to follow this same implementation strategy, but we assume they do for + // the purposes of this platform-generic cost analysis. + StatusOr ProcessNestedSubcomputation(HloComputation* computation); + StatusOr ProcessUnnestedSubcomputation( + HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9acee892d5993be3498d51ed66d7fa4647d7de88..ff32faf298dd1f04c5b769f2a88f76a7a1e18ae7 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -387,7 +387,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -429,7 +429,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -472,7 +472,7 @@ TEST_F(DomainCostAnalysis, DomainCost) { auto domain = builder.AddInstruction( 
HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8..aaa9ec60eb3c4e0159ed40b37d772e0973d306ec 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,22 +19,22 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -class HloCreationUtilsTest : public HloVerifiedTestBase { +class HloCreationUtilsTest : public HloTestBase { protected: - HloModule* CreateModuleWithProgramShape( + std::unique_ptr CreateModuleWithProgramShape( PrimitiveType primitive_type, absl::Span input_shape_dims, absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); - auto module = CreateNewModule("test"); + auto module = CreateNewVerifiedModule("test"); *entry_computation = module->AddEntryComputation( CreateComputationWithSignature({&input_shape}, output_shape, "entry") .ValueOrDie()); @@ -47,10 +47,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - 
/*output_shape_dims=*/{2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); @@ -67,9 +66,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, + auto module = CreateModuleWithProgramShape( + S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed, @@ -92,10 +90,9 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{1, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -113,10 +110,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -134,10 +130,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - 
/*input_shape_dims=*/{}, - /*output_shape_dims=*/{1, 1}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -154,10 +149,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6}, + /*output_shape_dims=*/{3, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded, ExpandFirstDimIntoNDims(param, {3, 1, 2})); @@ -176,10 +170,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{6}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -197,10 +190,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -218,10 +210,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(F32, - /*input_shape_dims=*/{}, - 
/*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9b18b0284f63c25934c1b7118dc8973caa62cadc..1eb0260468c4560985027947e89c62cc21139e7e 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloVerifiedTestBase { +class HloCseTest : public HloTestBase { protected: HloCseTest() {} }; @@ -59,13 +59,13 @@ TEST_F(HloCseTest, CombineTwoConstants) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); @@ -89,14 +89,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto add = 
builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); @@ -121,14 +121,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { shape_r0, HloOpcode::kAdd, root, constants[i])); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -201,7 +201,7 @@ TEST_F(HloCseTest, NonscalarConstants) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( {common_constant1, common_constant2, uncommon_constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -233,14 +233,14 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -277,21 +277,21 @@ index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - } - )"); + })"; - 
auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -327,20 +327,20 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 - } - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -369,22 +369,21 @@ condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), condition=%condition.1, body=%body - } 
- - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -411,13 +410,14 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -439,14 +439,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, 
computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -470,14 +470,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -488,7 +488,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { TEST_F(HloCseTest, FusionInternalCSE) { // Test that we can CSE expressions that live within a fusion node // computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); @@ -512,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -554,14 +554,14 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + 
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -586,7 +586,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -595,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -607,7 +607,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // rng_function is an impure function because it does RNG. 
HloComputation* rng_function = nullptr; @@ -649,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -659,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule m add_computation { @@ -680,11 +680,12 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })"); + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -697,19 +698,19 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -730,11 +731,12 @@ ENTRY %entry { 
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})"); +})"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - const HloInstruction* sub = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + const HloInstruction* sub = m->entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 909853106d57d181e85e3e4134b4039be2b176f5..e8eb7066f96537ff7d5a932434852bc4cf209281 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -43,7 +43,7 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(CreateNewModule()) {} + HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
@@ -1884,7 +1884,7 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, class HloDataflowAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -2476,7 +2476,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -2511,7 +2511,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 3b5cde2996c4195ef458662cd21de85a832d8d55..1fa4259a3e42286cbc911907eea563e6ca6f8611 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -59,7 +59,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -80,7 +80,7 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); - auto module = CreateNewModule(); + auto module = 
CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -110,7 +110,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -150,7 +150,7 @@ TEST_F(HloDceTest, ControlDependencies) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency between two instructions. @@ -175,7 +175,7 @@ TEST_F(HloDceTest, ControlDependencies) { // Tests that a dead call instruction is removed. TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Called computation for the call instruction. @@ -215,7 +215,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { // Tests that a while instruction with an infeed (effectul instruction) in its // body is not removed, even its user count is 0. TEST_F(HloDceTest, CalledComputationWithSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Condition computation of a while instruction. @@ -270,7 +270,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { // Tests that a nested call instruction with a side effect is not removed. 
TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Nested called computation with a side effect. @@ -323,7 +323,7 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { } TEST_F(HloDceTest, RemoveDeadSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); @@ -364,7 +364,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { } TEST_F(HloDceTest, KeepUsedSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index b90e8db23398d23e886d2d1fe68de8bb187d9c3a..acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" @@ -22,13 +22,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloVerifiedTestBase { +class HloDomainTest : public HloTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,13 +63,6 @@ class HloDomainTest : public HloVerifiedTestBase { } return false; } - - StatusOr ParseModule(absl::string_view hlo_string) { - HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_string, config); - return &module(); - } }; // Dummy DomainMetadata implementation which create kDomain boundaries around @@ -144,31 +136,32 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + 
EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -186,11 +179,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(!isolator_changed); } @@ -213,26 +207,27 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); 
EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "b", "a")); - EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); - EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "b", "a")); - EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -250,11 +245,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_FALSE(isolator_changed); } @@ -273,15 +269,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - 
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module, "c"); + HloInstruction* add = FindInstruction(module.get(), "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -304,41 +301,42 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator sharding_isolator([]() { return ShardingDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module)); + sharding_isolator.Run(module.get())); EXPECT_TRUE(sharding_isolator_changed); HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), 
OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -359,16 +357,17 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); - EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module.get(), "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. 
@@ -384,7 +383,7 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + HloInstruction* infeed_data = FindInstruction(module.get(), "infeed.data"); ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); @@ -410,7 +409,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); struct Assignment { @@ -446,25 +445,26 @@ ENTRY entry { sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); - EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + EXPECT_TRUE(HasDomainEdge(module.get(), "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module.get(), "gte", "tuple")); // Remove %tuple and %gte (tuple simplification) - HloInstruction* gte = FindInstruction(module, "gte"); - HloInstruction* tuple = FindInstruction(module, "tuple"); + HloInstruction* gte = FindInstruction(module.get(), "gte"); + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); 
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); const HloInstruction* root = module->entry_computation()->root_instruction(); @@ -486,11 +486,11 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto hlo_string = module->ToString(); - ASSERT_TRUE(ParseModule(hlo_string).status().ok()); + ASSERT_TRUE(ParseAndReturnVerifiedModule(hlo_string).status().ok()); } // Tuple inputs are domain instructions. @@ -507,20 +507,21 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tpl instruction, in order to test domain sharding // application. 
- auto tpl = FindInstruction(module, "tpl"); + auto tpl = FindInstruction(module.get(), "tpl"); tpl->clear_sharding(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), @@ -555,36 +556,37 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool 
opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } // Emulate instructions inserted at top and bottom within nested tuple domain. @@ -603,15 +605,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tuple.0 instruction, in order to test domain sharding // application. 
- auto tuple0 = FindInstruction(module, "tuple.0"); + auto tuple0 = FindInstruction(module.get(), "tuple.0"); tuple0->clear_sharding(); // Insert the following instructons above and below tuple.0, to emulate other @@ -655,7 +658,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_TRUE(tuple0->has_sharding()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 7fcafafc097a623686ca98a7cb3c6256c7904f6d..9783f0574f50ba5542b82d36da899f968ce0e45c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1279,10 +1279,10 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, key_value_vector.push_back( std::make_pair(keys_data[i], values_data[i])); } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); + std::stable_sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); std::vector result_keys; // We use a InlinedVector here because we need to convert it to an // absl::Span later, and this would not work with std::vector. 
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 07f8d0aad4af0b07303b4e485b3630cc75bcb519..d751f40fff872b831338dc8aa08a04cb00d2838c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -221,16 +221,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const Literal& operand_literal) { const auto shape = instruction->shape(); const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape), - ShapeUtil::HumanString(operand->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); Literal result(shape); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 608a42bb60702aa075daca39535ca1672dcc5467..d95b6ad04f2c446b423a3aaef4de333ed2968883 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -50,9 +50,9 @@ namespace { static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloVerifiedTestBase { + public HloTestBase { protected: - HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } @@ -60,14 +60,14 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(&module()).ValueOrDie(); + type_converter.Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - // Evaluate function that takes in a local module instead of using module_ - // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // Evaluate function that takes in a local module instead of using m_ + // that is in HloTestBase. Once m_ in HloTestBase is // removed, this should be the default Evaluate function. 
Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { @@ -88,7 +88,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -108,7 +108,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -116,6 +116,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, } bool use_bfloat16_; + std::unique_ptr m_ = CreateNewVerifiedModule(); }; #define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ @@ -135,7 +136,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -156,7 +157,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -181,7 +182,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal 
result = Evaluate({}); @@ -322,7 +323,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(args); @@ -346,7 +347,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -367,7 +368,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, {1, 2})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -386,7 +387,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -406,7 +407,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -428,7 +429,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -448,7 +449,7 @@ TEST_P(HloEvaluatorTest, 
ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -468,7 +469,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -503,7 +504,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -530,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -574,7 +575,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -619,7 +620,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -658,7 +659,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { 
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -704,7 +705,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -748,7 +749,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -802,7 +803,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -857,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -941,7 +942,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1019,7 +1020,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { b.AddInstruction(HloInstruction::CreateConvolve( shape, 
lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1079,7 +1080,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1143,7 +1144,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1215,7 +1216,7 @@ TEST_P(HloEvaluatorTest, b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1286,7 +1287,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1297,11 +1298,12 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). 
TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1319,12 +1321,12 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m->AddEmbeddedComputation(add_computation.Build()); HloInstruction* reduce_instruction = b.AddInstruction( HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{0}, add_func)); - module().AddEntryComputation(b.Build()); + m->AddEntryComputation(b.Build()); HloEvaluator hlo_eval; Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); @@ -1337,7 +1339,7 @@ void BM_ReducePrecisely(int num_iters) { tensorflow::testing::StopTiming(); HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1396,14 +1398,14 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction( HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - module().AddEntryComputation(b.Build()); + 
m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1438,7 +1440,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1455,7 +1457,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1490,7 +1492,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1507,7 +1509,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1541,7 +1543,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; WindowDimension dim; @@ -1564,7 +1566,7 @@ TEST_P(HloEvaluatorTest, 
ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1594,7 +1596,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; @@ -1625,7 +1627,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1657,7 +1659,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1691,7 +1693,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1727,7 +1729,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1764,7 +1766,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, 
operand, update, start_indices)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1800,7 +1802,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1839,7 +1841,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1877,7 +1879,7 @@ TEST_P(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1966,7 +1968,7 @@ ENTRY main { slice_sizes={1, 3} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -1990,7 +1992,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -2014,7 +2016,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2039,7 +2041,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - 
ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2066,7 +2068,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2092,7 +2094,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({1, 1}); @@ -2115,7 +2117,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2139,7 +2141,7 @@ ENTRY main { slice_sizes={1, 0} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), @@ -2161,7 +2163,7 @@ ENTRY main { slice_sizes={1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal start_indices = @@ -2192,7 +2194,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2223,7 +2225,7 @@ ENTRY main { 
index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2256,7 +2258,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2288,7 +2290,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2320,7 +2322,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); @@ -2354,7 +2356,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2386,7 +2388,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2418,7 +2420,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, 
{-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2455,7 +2457,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2491,7 +2493,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2523,7 +2525,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2555,7 +2557,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{}, {}}); @@ -2585,7 +2587,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal scatter_indices = @@ -2736,7 +2738,7 @@ ENTRY main { ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); @@ -2754,7 +2756,7 @@ ENTRY main { ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} } )"; - ParseAndVerifyModule(hlo_text); + 
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index ebed875eb4954bc9a9da3f182005fa3d44326493..b87fc3e34012e75ee07bff6c1e113dce404f83cb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -161,9 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloOpcodeString(hlo_instruction->opcode())); } - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive type. - template ::value>::type* = nullptr> @@ -596,7 +593,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) { + Status HandleDivide(HloInstruction* divide) override { return HandleDivide(divide); } @@ -1556,10 +1553,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto& row_data = row_to_sort.data(); std::vector result_data(row_data.begin(), row_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); + std::stable_sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), {sort_dim_elements})); sorted_row.PopulateR1(absl::Span(result_data)); @@ -2546,12 +2543,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || - std::is_same::value || - std::is_same::value>::type* = nullptr> + std::is_integral::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* 
iota = Cast(instruction); - std::vector data(iota->shape().dimensions(iota->iota_dimension())); + // Avoid using std::vector since std::vector does not convert to + // absl::Span. + absl::InlinedVector data( + iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); auto result = LiteralUtil::CreateR1(data); @@ -2568,9 +2567,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || - std::is_same::value || - std::is_same::value)>::type* = nullptr> + !(std::is_integral::value || + std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { return InvalidArgument("Unsupported type for iota"); } @@ -2722,17 +2720,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -2756,19 +2745,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. 
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape()), - ShapeUtil::HumanString(ehs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, lhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(rhs->shape(), ehs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..631b3ad735f369922d10b37d11e2a1b1ba117e6b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + +namespace xla { + +namespace { + +StatusOr ReplaceGetSize(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kGetDimensionSize) { + return false; + } + HloComputation* computation = instr->parent(); + + TF_ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); + uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instr, new_instr)); + return true; +} + +} // namespace + +StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { + bool changed = false; + HloProto proto; + *proto.mutable_hlo_module() = module->ToProto(); + for (auto* computation : module->computations()) { + // Replacing instructions will change the instruction list in the + // computation. So instead of iterating computation->instructions() + // directly, we make a copy of the list to avoid use-after-free. 
+ std::vector instrs(computation->instruction_count()); + absl::c_copy(computation->instructions(), instrs.begin()); + for (auto instruction : instrs) { + TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + changed = changed || replaced; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..30f44c23a835b3bcc935caaa917e040e07c4e703 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Pass to replace a kGetDimensionSize instruction with a constant instruction. 
+class HloGetDimensionSizeRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "hlo-get-dimension-size-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a86aebdd5b64240e6e07d8e8050c0c8681cce765 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HloGetDimensionSizeRewriterTest : public HloTestBase { + protected: + HloGetDimensionSizeRewriterTest() {} +}; + +TEST_F(HloGetDimensionSizeRewriterTest, Ok) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = u32[] get-dimension-size(p), dimensions={0} + size1 = u32[] get-dimension-size(p), dimensions={1} + ROOT mul = u32[] multiply(size0, size1) +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3]{0} parameter(0) + ROOT gds = s64[] get-dimension-size(p), dimensions={0} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { + auto module = 
ParseHloString(R"( +HloModule _ +ENTRY gds { + p = f32[2,5] parameter(0) + ROOT gds = u32[] get-dimension-size(p), dimensions={2} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 13a74fd8a115c5dc9a9518b226dfee4445cc7180..05cc1593e4ef4fc52b94e0536628645b1fa2abbc 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1043,6 +1043,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGetDimensionSize: return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f6ed86b41650fd331201814559386ff644092c23..cd95052580b3d203c2d2a586bc4d9fdbb9d19bf4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -312,6 +312,10 @@ StatusOr> HloInstruction::CreateFromProto( proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { + TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) && + (ShapeUtil::TupleElementCount(proto.shape()) == 2)) + << "Infeed should have a tuple shape with 2 operands, but has: " + << proto.shape(); const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -530,6 +534,12 @@ StatusOr> HloInstruction::CreateFromProto( absl::make_unique(exit_hlo_sharding)); break; } + case HloOpcode::kGetDimensionSize: + TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.dimensions_size() == 1); + instruction = CreateGetDimensionSize(proto.shape(), operands(0), + 
proto.dimensions(0)); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -1001,6 +1011,14 @@ HloInstruction::CreateSelectAndScatter( broadcast_dimensions); } +/* static */ std::unique_ptr +HloInstruction::CreateGetDimensionSize(const Shape& shape, + HloInstruction* operand, + int64 dimension) { + return absl::make_unique(shape, operand, + dimension); +} + /* static */ std::unique_ptr HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, @@ -1109,7 +1127,11 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) { void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { - if (sharding_ != nullptr) { + if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType( + shape_, derived_instruction->shape())) { + // Only copy sharding if the shape of the two instruction is compatible + // because copying it between differently shaped instructions can produce + // invalid shardings. derived_instruction->set_sharding(*sharding_); } else { derived_instruction->clear_sharding(); @@ -1268,6 +1290,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kIota: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. 
@@ -1715,6 +1738,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kScatter: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1876,6 +1900,11 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } +HloInstruction* HloInstruction::while_init() const { + CHECK_EQ(HloOpcode::kWhile, opcode_); + return operands_[0]; +} + HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -2440,6 +2469,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAfterAll(this); case HloOpcode::kIota: return visitor->HandleIota(this); + case HloOpcode::kGetDimensionSize: + return visitor->HandleGetDimensionSize(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2597,36 +2628,6 @@ Status HloInstruction::AcceptWithOperandOrder( return Status::OK(); } -namespace { - -// Returns true if the given order is a topological sort of the instructions -// it contains. -bool OrderIsTopologicalSort(const std::vector& order) { - // Create a map from instruction to its position in 'order'. - std::unordered_map order_position; - for (int i = 0; i < order.size(); i++) { - if (!order_position.insert({order[i], i}).second) { - // Instruction order[i] is duplicated in the order. - return false; - } - } - // Verify that the operand of each instruction in the order is also in the - // order *and* the operand's position is earlier (defs are before uses for - // all ops). 
- for (auto* instruction : order) { - for (auto* operand : instruction->operands()) { - if (!ContainsKey(order_position, operand) || - order_position.at(operand) >= order_position.at(instruction)) { - return false; - } - } - } - - return true; -} - -} // namespace - Status HloInstruction::Accept( const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); @@ -2639,49 +2640,7 @@ Status HloInstruction::Accept( return this->Accept(&visitor); } -Status HloInstruction::AcceptOrdered( - DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; - TF_RET_CHECK(OrderIsTopologicalSort(order)); - - // Compute the predecessors of this instruction. - std::unordered_set predecessors; - TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) { - predecessors.insert(instruction); - return Status::OK(); - })); - - for (auto* const_instruction : order) { - if (!ContainsKey(predecessors, const_instruction)) { - // Instruction is not a predecessors of 'this'. - continue; - } - - // The visitor can mark instructions as visited to skip particular - // instructions. - if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO %" << const_instruction->name() - << " as it was already visited."; - continue; - } - - // TODO(b/78350259): Eliminate const laundering. 
- HloInstruction* instruction = - const_cast(const_instruction); - - TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO %" << instruction->name(); - TF_RETURN_IF_ERROR(instruction->Visit(visitor)); - visitor->SetVisited(*instruction); - TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); - } - - return visitor->FinishVisit(this); -} - -const Shape& HloInstruction::shape() const { - return shape_; -} +const Shape& HloInstruction::shape() const { return shape_; } std::vector HloInstruction::OperandIndices( const HloInstruction* operand) const { @@ -3080,6 +3039,10 @@ int64 HloInstruction::concatenate_dimension() const { return Cast(this)->concatenate_dimension(); } +int64 HloInstruction::dimension() const { + return Cast(this)->dimension(); +} + bool HloInstruction::IsRank2Transpose() const { auto transpose = DynCast(this); return transpose != nullptr && transpose->IsRank2Transpose(); @@ -3259,6 +3222,11 @@ absl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } +void HloInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + return Cast(this)->set_all_reduce_id(all_reduce_id); +} + const ConvolutionDimensionNumbers& HloInstruction::convolution_dimension_numbers() const { if (auto convolution = DynCast(this)) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 15a4da8dbe0053aad314989a6718ebd61532ab8b..95ad29235afa36dc4091feec54cd4b0f5f24048f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -767,6 +767,9 @@ class HloInstruction { // when we plumb a primordial token from the entry computation. static std::unique_ptr CreateToken(); + static std::unique_ptr CreateGetDimensionSize( + const Shape& shape, HloInstruction* operand, int64 dimension); + // Returns the opcode for this instruction. 
HloOpcode opcode() const { return opcode_; } @@ -880,11 +883,15 @@ class HloInstruction { return false; } - // Use an explicit loop rather than ContainerEquals, because copying around - // std::functions may be too expensive in some cases. - for (size_t i = 0; i < operands().size(); ++i) { - if (!eq_operands(operand(i), other.operand(i))) { - return false; + // Two AllReduces are Identical if they have the same all_reduce_id. + // Their operands don't have to be Identical. + if (!this->IsCrossModuleAllReduce()) { + // Use an explicit loop rather than ContainerEquals, because copying + // around std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } } } @@ -954,16 +961,6 @@ class HloInstruction { Status Accept( const std::function& visitor_func) const; - // Visits all instructions rooted at this instruction using the given visitor - // in the given order. 'order' must contain at least the set of instructions - // rooted at this node (ie, those accessible from a DFS traversal from this - // instruction). Instructions contained in 'order' which are not in the set of - // instructions rooted at this node are ignored. 'order' must also be a valid - // topological sort of these instructions (defs appear before uses) though - // need not be a DFS post-order. - Status AcceptOrdered(DfsHloVisitor* visitor, - const std::vector& order); - // Visit this instruction and only this instruction with the given visitor. template Status Visit(DfsHloVisitorBase* visitor); @@ -1004,6 +1001,8 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); + HloInstruction* while_init() const; + // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. 
// @@ -1324,6 +1323,9 @@ class HloInstruction { // Delegates to HloConcatenateInstruction::concatenate_dimension. int64 concatenate_dimension() const; + // Delegates to HloGetDimensionSizeInstruction::dimension. + int64 dimension() const; + // Returns whether this instruction does a rank-2 transposition. bool IsRank2Transpose() const; @@ -1442,6 +1444,7 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns data on the window in a windowed operation such as // convolution. diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d93351fe0435b5f29035dc4ea0621a8c576bfd5a..8048e332cb57747286758b75773b29ba154aa888 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,7 +39,7 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloVerifiedTestBase { +class HloInstructionTest : public HloTestBase { protected: Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -151,7 +151,7 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); @@ -188,7 +188,7 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); @@ -221,7 +221,7 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -256,7 +256,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = 
builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -305,7 +305,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -327,7 +327,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Builds an x+1.0 computation to use in a Map. auto embedded_builder = HloComputation::Builder("f32+1"); @@ -375,7 +375,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. 
@@ -416,7 +416,7 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -451,7 +451,7 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -479,7 +479,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -516,7 +516,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -546,7 +546,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, 
foo->user_count()); @@ -611,7 +611,7 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; @@ -629,7 +629,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -647,7 +647,7 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -669,7 +669,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -692,7 +692,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -749,7 +749,7 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); @@ -817,7 +817,7 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -977,7 +977,7 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); int visit_num = 0; @@ -1006,7 +1006,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); @@ -1016,7 +1016,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { } TEST_F(HloInstructionTest, MapIsElementwise) { - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); HloComputation::Builder builder(TestName()); HloComputation::Builder map_builder("id"); @@ -1067,7 +1067,7 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); @@ -1108,7 +1108,7 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); @@ -1151,7 +1151,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1192,7 +1192,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = 
computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1204,7 +1204,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { } TEST_F(HloInstructionTest, FusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1226,7 +1226,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1330,7 +1330,7 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( @@ -1373,7 +1373,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1408,7 +1408,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1443,7 +1443,7 @@ TEST_F(HloInstructionTest, StringifyScatter) { update_builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* update_computation = 
module->AddEmbeddedComputation(update_builder.Build()); @@ -1495,7 +1495,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1531,7 +1531,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1587,7 +1587,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 88495e80000c4f87a778c4fad747f6bdf09b7a14..ed3b2f1564103969a1092f3215f8b6a377d2d2ae 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -370,6 +370,11 @@ HloAllReduceInstruction::HloAllReduceInstruction( AppendComputation(reduce_computation); } +void HloAllReduceInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + all_reduce_id_ = all_reduce_id; +} + 
HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. @@ -2349,4 +2354,43 @@ HloInstructionProto HloDomainInstruction::ToProto() const { return proto; } + +HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction( + const Shape& shape, HloInstruction* operand, int64 dimension) + : HloInstruction(HloOpcode::kGetDimensionSize, shape), + dimension_(dimension) { + AppendOperand(operand); +} + +HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(dimension()); + return proto; +} + +std::vector HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + return {StrCat("dimensions={", dimension(), "}")}; +} + +bool HloGetDimensionSizeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return dimension() == casted_other.dimension(); +} + +std::unique_ptr +HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + if (new_operands.size() != 1) { + LOG(FATAL) << "expects 1 operand"; + } + return absl::make_unique( + shape, new_operands[0], dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index bf4daf2be47ed06d2b88a331a56149d38fa646b3..0b07341cb94c1391c787ec8e0f5a3f17dccc96b2 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -252,6 +252,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { } absl::optional all_reduce_id() const { return all_reduce_id_; } + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns a 
serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1385,6 +1386,33 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; }; + +class HloGetDimensionSizeInstruction : public HloInstruction { + public: + explicit HloGetDimensionSizeInstruction(const Shape& shape, + HloInstruction* operand, + int64 dimension); + + // Returns the dimension sizes or numbers associated with this instruction. + int64 dimension() const { return dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 1717770301e3666b0a1c23d20b7f2e3bac5f62e4..170ec93a334903cdc314f1950675ef30bc4cda5a 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -165,6 +165,7 @@ namespace opcode_matchers { } HLO_MATCHER(Abs); HLO_MATCHER(Add); +HLO_MATCHER(AllToAll); HLO_MATCHER(Bitcast); HLO_MATCHER(Broadcast); HLO_MATCHER(BatchNormGrad); @@ -178,6 +179,7 @@ HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc 
b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 5cee865b7ad34eded1743d9d5455bb40febf6182..d2740bcce26f04c5d7c8b64cfdaea53e3c697855 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -73,7 +73,7 @@ class ListScheduler { // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. static StatusOr Run( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -98,7 +98,7 @@ class ListScheduler { // comparison operators. using Priority = std::pair; - ListScheduler(const HloComputation& computation, + ListScheduler(HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -111,7 +111,7 @@ class ListScheduler { // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( @@ -126,13 +126,13 @@ class ListScheduler { // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. 
- for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { unscheduled_use_count_[buffer] = 0; } } - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { ++unscheduled_use_count_[buffer]; } @@ -141,7 +141,7 @@ class ListScheduler { // Buffers live out of the computation have an implicit use at the end of // the computation. for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) + points_to_analysis.GetPointsToSet(computation->root_instruction()) .CreateFlattenedSet()) { ++unscheduled_use_count_[live_out_buffer]; } @@ -157,7 +157,7 @@ class ListScheduler { // HloInstruction, plus some cached metadata, saved for the purposes of making // BytesFreedIfScheduled fast. struct ReadyListEntry { - const HloInstruction* instruction; + HloInstruction* instruction; // The total size of all buffers defined by this instruction. int64 bytes_defined; @@ -171,7 +171,7 @@ class ListScheduler { }; // Creates a ReadyListEntry for the given instruction. - ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { + ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) { ReadyListEntry entry; entry.instruction = instruction; @@ -250,13 +250,13 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. absl::flat_hash_map unscheduled_pred_count; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. 
- for (const HloInstruction* user : instruction->users()) { + for (HloInstruction* user : instruction->users()) { unscheduled_pred_count[user]++; } - for (const HloInstruction* succ : instruction->control_successors()) { + for (HloInstruction* succ : instruction->control_successors()) { unscheduled_pred_count[succ]++; } } @@ -275,7 +275,7 @@ class ListScheduler { ready_instructions[inst] = it; }; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { if (instruction->operands().empty() && instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); @@ -287,7 +287,7 @@ class ListScheduler { // schedule. auto best_it = ready_queue.end(); --best_it; - const HloInstruction* best = best_it->second.instruction; + HloInstruction* best = best_it->second.instruction; VLOG(2) << "Schedule instruction: " << best->ToShortString() << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); @@ -348,13 +348,13 @@ class ListScheduler { } } } - CHECK_EQ(schedule.size(), computation_.instruction_count()); - CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); + CHECK_EQ(schedule.size(), computation_->instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count()); return schedule; } - const HloComputation& computation_; + HloComputation* computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; // Computations are analyzed in post-order. 
When scheduling an instruction @@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes( } StatusOr ScheduleComputationHelper( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, const absl::flat_hash_map& memory_by_computation) { - VLOG(2) << "Computation: " << computation.name(); + VLOG(2) << "Computation: " << computation->name(); if (algorithm) { return algorithm(computation, points_to_analysis, size_function, memory_by_computation); @@ -404,17 +404,17 @@ StatusOr ScheduleComputationHelper( } // namespace StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->instruction_count(); + int64 total_hlos = computation->parent()->instruction_count(); absl::flat_hash_map extra_users; absl::flat_hash_map total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; total_sizes[hlo] = 0; @@ -448,8 +448,8 @@ StatusOr DFSMemoryScheduler( total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } - CHECK_EQ(extra_users.size(), computation.instruction_count()); - CHECK_EQ(total_sizes.size(), computation.instruction_count()); + CHECK_EQ(extra_users.size(), computation->instruction_count()); + CHECK_EQ(total_sizes.size(), computation->instruction_count()); // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, 
and next by cumulative size, with a @@ -459,7 +459,7 @@ StatusOr DFSMemoryScheduler( sequence.push_back(hlo); return Status::OK(); }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder( &visitor, [&extra_users, &total_sizes](const HloInstruction* a, const HloInstruction* b) { if (extra_users[a] != extra_users[b]) { @@ -470,12 +470,12 @@ StatusOr DFSMemoryScheduler( } return a->name() < b->name(); })); - CHECK_EQ(sequence.size(), computation.instruction_count()); + CHECK_EQ(sequence.size(), computation->instruction_count()); return sequence; } // namespace xla StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -485,16 +485,16 @@ StatusOr ListMemoryScheduler( } StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { - return HloInstructionSequence(computation.MakeInstructionPostOrder()); + return HloInstructionSequence(computation->MakeInstructionPostOrder()); } StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -513,7 +513,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, HeapSimulator::MinimumMemoryForComputation( - computation, list_sequence, points_to_analysis, + *computation, list_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); @@ -522,7 +522,7 @@ StatusOr DefaultMemoryScheduler( size_function, 
memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, HeapSimulator::MinimumMemoryForComputation( - computation, dfs_sequence, points_to_analysis, + *computation, dfs_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); @@ -532,7 +532,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, HeapSimulator::MinimumMemoryForComputation( - computation, post_order_sequence, points_to_analysis, + *computation, post_order_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -555,17 +555,17 @@ StatusOr DefaultMemoryScheduler( } StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - HloSchedule schedule(&module); + HloSchedule schedule(module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); + TuplePointsToAnalysis::Run(module)); absl::flat_hash_map memory_by_computation; - for (const auto* computation : module.MakeComputationPostOrder()) { + for (auto* computation : module->MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( - *computation, *points_to_analysis, size_function, + computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( @@ -583,11 +583,11 @@ StatusOr ScheduleModule( } StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function) { - CHECK(!computation.IsFusionComputation()); 
+ CHECK(!computation->IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); + TuplePointsToAnalysis::Run(computation->parent())); absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); @@ -600,7 +600,24 @@ HloMemoryScheduler::HloMemoryScheduler( StatusOr HloMemoryScheduler::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, size_function_, algorithm_)); + ScheduleModule(module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + +StatusOr HloTrivialScheduler::Run(HloModule* module) { + HloSchedule schedule(module); + for (HloComputation* computation : module->MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + HloInstructionSequence& computation_sequence = + schedule.GetOrCreateSequence(computation); + TF_RETURN_IF_ERROR(computation->Accept( + [&computation_sequence](HloInstruction* instruction) { + computation_sequence.push_back(instruction); + return Status::OK(); + })); + } + } TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index a4c1d3db8170a1725043def576f913e09b352e5d..7227bfb27c74758d2b79e404afc9eb97a1ca894d 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -36,14 +36,14 @@ namespace xla { // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. 
typedef std::function( - const HloComputation&, const TuplePointsToAnalysis&, + HloComputation*, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const absl::flat_hash_map&)> MemorySchedulerAlgorithm; // List scheduler StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -51,7 +51,7 @@ StatusOr ListMemoryScheduler( // DFS-order scheduler StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -59,7 +59,7 @@ StatusOr DFSMemoryScheduler( // Naive Post Order scheduler StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -69,7 +69,7 @@ StatusOr PostOrderMemoryScheduler( // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -79,13 +79,13 @@ StatusOr DefaultMemoryScheduler( // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. 
StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function); // A pass which schedules the HLO instructions in a module. The HloModule's @@ -108,6 +108,15 @@ class HloMemoryScheduler : public HloModulePass { MemorySchedulerAlgorithm algorithm_; }; +// A pass which produces a naive, but correct schedule. The schedule is produced +// using a DFS traversal of the graph with no attempt to minimize memory use. +class HloTrivialScheduler : public HloModulePass { + public: + absl::string_view name() const override { return "hlo-trivial-scheduler"; } + + StatusOr Run(HloModule* module) override; +}; + // A trivial pass which clears the schedule currently set on the // HloModule. After this pass runs HloModudle::has_schedule will return false. class HloDescheduler : public HloModulePass { diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 214119fba881c4411a262cd4227b5cc49cef0d14..bc0d7e2bc00eab014f2660c95a51b966642eaee9 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -65,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloMemoryScheduler scheduler([](const BufferValue& buffer) { @@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK(module->schedule().Verify()); // Verify that all instructions are in the sequence. 
- const std::vector& sequence = + const std::vector& sequence = module->schedule().sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -124,9 +124,9 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - const std::vector& sequence = + const std::vector& sequence = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -172,15 +172,16 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, abs_abs2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -218,19 +219,19 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), 2); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -242,7 +243,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { } TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); // param != 0 @@ -252,7 +253,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction::CreateParameter(0, r1f32, "cond_param")); HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + LiteralUtil::CreateR1({0, 0, 0, 0}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + 
ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -309,5 +310,40 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { .ValueOrDie()); } +TEST_F(HloSchedulingTest, TrivialScheduler) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + EXPECT_FALSE(module->has_schedule()); + TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + + // Verify that a clone of the module also has a schedule. 
+ std::unique_ptr clone = module->Clone(); + ASSERT_TRUE(clone->has_schedule()); + TF_ASSERT_OK(clone->schedule().Verify()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index bcd709c973920d36f6b7f16a1a1a38dbf7fdf0cf..59f44475df55311992d41aecfb1f2f4e53a2e316 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -242,6 +242,8 @@ HloModuleProto HloModule::ToProto() const { *proto.mutable_host_program_shape() = entry_computation_layout().ComputeProgramShape(); *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); + *proto.mutable_dynamic_parameter_binding() = + dynamic_parameter_binding().ToProto(); return proto; } @@ -325,6 +327,10 @@ StatusOr> HloModule::CreateFromProto( // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. + TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_, + DynamicParameterBinding::CreateFromProto( + proto.dynamic_parameter_binding())); + absl::flat_hash_set computation_names; absl::flat_hash_set instruction_names; absl::flat_hash_set computation_ids; @@ -559,11 +565,28 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { std::unique_ptr HloModule::Clone(const HloModuleConfig& config, const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = absl::make_unique(name_ + "-" + suffix, config); + auto module = absl::make_unique( + absl::StrCat(name_, suffix.empty() ? 
"" : "-", suffix), config); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); module->AddEntryComputation(std::move(cloned_computation)); + + if (has_schedule() && schedule().Verify().ok()) { + HloSchedule clone_schedule(module.get()); + for (HloComputation* computation : computations()) { + if (schedule().is_computation_scheduled(computation)) { + HloInstructionSequence& clone_sequence = + clone_schedule.GetOrCreateSequence( + context.GetComputation(computation)); + for (const HloInstruction* instruction : + schedule().sequence(computation).instructions()) { + clone_sequence.push_back(context.GetInstruction(instruction)); + } + } + } + TF_CHECK_OK(module->set_schedule(std::move(clone_schedule))); + } return module; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 8a1f999e3ab076b87a651a915f4de93320e7067f..66622a1d260c28078d69b01b858fd292b697805b 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -103,11 +104,7 @@ class HloModule { HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module. 
- const HloComputation* entry_computation() const { - CHECK_NE(nullptr, entry_computation_); - return entry_computation_; - } - HloComputation* entry_computation() { + HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } @@ -232,6 +229,16 @@ class HloModule { return input_output_alias_config_; } + // DynamicParameterBinding holds the list of bindings that indicates which + // parameter dimensions are dynamic and which parameters represent their + // runtime value. + DynamicParameterBinding& dynamic_parameter_binding() { + return dynamic_parameter_binding_; + } + const DynamicParameterBinding& dynamic_parameter_binding() const { + return dynamic_parameter_binding_; + } + // Returns an id that is unique to this module across all modules created over // the lifetime of this process. int unique_id() const { return unique_id_; } @@ -285,6 +292,9 @@ class HloModule { // alias_config indicates the alias information of input/output buffers that // are expected from the module. HloInputOutputAliasConfig input_output_alias_config_; + + // Bindings for dynamic parameter mapping. + DynamicParameterBinding dynamic_parameter_binding_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 39f38b417ab0e8b54864176d8d1e0ad1a422eca6..620cb7e01ad1a060915f5b73474f6950ab18122a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -63,7 +63,7 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); EXPECT_THAT(module->MakeComputationPostOrder(), @@ -72,7 +72,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); @@ -88,7 +88,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { TEST_F(HloModuleTest, CloneTest) { // Create and copy a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -111,7 +111,7 @@ TEST_F(HloModuleTest, CloneTest) { } TEST_F(HloModuleTest, CloneHasFusion) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Create the fused computation. HloComputation* fused_computation; @@ -154,7 +154,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -174,7 +174,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { TEST_F(HloModuleTest, LargeConstantToString) { // Create a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( @@ -194,8 +194,8 @@ TEST_F(HloModuleTest, LargeConstantToString) { } TEST_F(HloModuleTest, UniqueModuleId) { - auto module_a = CreateNewModule(); - auto module_b = CreateNewModule(); + auto module_a = CreateNewVerifiedModule(); + auto module_b = CreateNewVerifiedModule(); EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e6bfb8025d4bfeba1d334d1f946e33841a2da092..70c7d70b41c5c7bc94d1fac83c0fcf71f155b5f0 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -83,6 +83,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGetDimensionSize, "get-dimension-size") \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 23d41d91d6969ddf9062507e926ae39c1e1315d4..ca6a154809be46d6a0305c29e2b89219de408019 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -334,7 +334,7 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. 
for (auto* computation : module->MakeNonfusionComputations()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, HloReachabilityMap::Build(computation)); } } @@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. TF_DCHECK_OK(schedule_.Verify()); for (const auto& computation_sequence : schedule_.sequences()) { - const std::vector& order = - computation_sequence.second.instructions(); + const auto& order = computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { InsertOrDie(&order_position_, order[i], i); } @@ -374,11 +373,10 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( return order_position_.at(a) < order_position_.at(b); } -const std::vector* -SequentialHloOrdering::SequentialOrder( +const HloInstructionSequence* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { return schedule_.is_computation_scheduled(&computation) - ? &schedule_.sequence(&computation).instructions() + ? &schedule_.sequence(&computation) : nullptr; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 66313492eb2dd10ac9a6000639ddb8991b367c0f..a07214c22c0989a438f12219e136a7e76ee0dcce 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +65,7 @@ class HloOrdering { // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. - virtual const std::vector* SequentialOrder( + virtual const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const = 0; // Return the call graph of the module used to compute ordering. @@ -96,7 +97,7 @@ class PredecessorHloOrdering : public HloOrdering { // Returns nullptr indicating the computation does not have a sequential // ordering. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return nullptr; } @@ -185,7 +186,7 @@ class SequentialHloOrdering : public HloOrdering { ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override; string ToString() const override; diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index b045adc9640ac0ca8cf4a127fea2fbfcbb1aaf3f..3ca77e60cd5275c22eb0e338cd5437fc44b49958 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -53,7 +53,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // %c = Constant(42.0f) // // This results in a diamond-shaped callgraph. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder_c = HloComputation::Builder("C"); @@ -126,7 +126,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { // %constant = Constant(1.0) // return While(%constant, body, condition) // - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -176,7 +176,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { // Entry parameter should always be defined before other instruction. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -209,7 +209,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // %while = While(%constant, body, condition) // %add = Add(%constant, %while) // - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -407,7 +407,7 @@ TEST_F(HloOrderingTest, // %dead = Constant(123.0) // // %root should interfere with %dead. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -455,7 +455,7 @@ TEST_F(HloOrderingTest, // ROOT %call = call({%c}), subcomputation // // %root should interfere with %dead. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto subbuilder = HloComputation::Builder(TestName() + ".sub"); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e0011398aad133d07d31c419626e4be54228f9de..4bf287a9ed585889669c22bb61873be2887ff66a 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -47,11 +47,11 @@ const double kF16max = 65504; // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. -HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { +HloSchedule ScheduleFromInstructionOrder(HloModule* module) { HloSchedule schedule(module); - for (const HloComputation* computation : module->computations()) { + for (HloComputation* computation : module->computations()) { if (!computation->IsFusionComputation()) { - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { schedule.GetOrCreateSequence(computation).push_back(instruction); } } @@ -108,7 +108,7 @@ class HloParser { bool ParseInstructionList(HloComputation** computation, const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); @@ -608,10 +608,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } - return ParseInstruciontRhs(builder, name, name_loc); + return ParseInstructionRhs(builder, name, name_loc); } -bool 
HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, +bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc) { Shape shape; HloOpcode opcode; @@ -1547,6 +1547,18 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); + case HloOpcode::kGetDimensionSize: + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateGetDimensionSize( + shape, operands[0], (*dimensions)[0])); + break; } instruction->SetAndSanitizeName(name); @@ -1806,6 +1818,10 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value, case U64: return SetValueInLiteralHelper(value, linear_index, literal); + case PRED: + // Bool type literals with rank >= 1 are printed in 0s and 1s. + return SetValueInLiteralHelper(static_cast(value), + linear_index, literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); @@ -2060,14 +2076,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - // TODO(congliu): bool type literals with rank >= 1 are actually - // printed in a compact form instead of "true" or "false". Fix that. 
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, linear_index++, literal)) { return false; } lexer_.Lex(); - } else if (primitive_util::IsIntegralType(shape.element_type())) { + } else if (primitive_util::IsIntegralType(shape.element_type()) || + shape.element_type() == PRED) { LocTy loc = lexer_.GetLoc(); tensorflow::int64 value; if (!ParseInt64(&value)) { @@ -2705,7 +2720,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - std::vector split1 = absl::StrSplit(str, "_"); + std::vector split1 = absl::StrSplit(str, '_'); if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; @@ -3386,7 +3401,7 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { // e.g. // // f32[10] fusion(...), calls={...} - if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) { return false; } } else { diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 81eeb9f13bf7f06123c0b35e9f3352c197866a7a..d830fa61438239005875f785f85cf2486123ebc9 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,7 +44,9 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, +// e.g., "{replicated}". StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). 
@@ -55,10 +57,6 @@ StatusOr ParseWindow(absl::string_view str); StatusOr ParseConvolutionDimensionNumbers( absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr ParseSharding(absl::string_view str); - // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 85d07d8092ce19089543f5f11be9f4a58cbf132f..88682e55fb37e6cacbeaf5826286cc9f70e57e3b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -75,6 +75,18 @@ ENTRY %constant_pred () -> pred[] { )" }, +// pred array constant +{ +"ConstantPredArray", +R"(HloModule module + +ENTRY %constant_pred_array () -> pred[2,3] { + ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) +} + +)" +}, + // s32 constant { "ConstantS32", @@ -183,7 +195,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) } )" @@ -575,7 +587,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 
} } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -1138,6 +1150,25 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } +)" +}, +// cross-replica-sum with all-reduce-id +{ +"CrossReplicaSumAllReduce", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add +} + )" }, // all-to-all diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc index ee8cb12b231718e09f6ac0d05d7a6887f4c4d746..20384b9da6be4bab447b474f0e2240bcb277a620 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloPassPipelineTest : public HloVerifiedTestBase { +class HloPassPipelineTest : public HloTestBase { protected: StatusOr ParseModuleGroup( absl::Span hlo_strings) { diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2d5197be9e6f69f698729e06b7506a5bc6260bcd..f968a4a94453f678f5c17e0b8d1df4aea70c93ea 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -104,5 +104,20 @@ bool IsScalarConstant(const HloInstruction* instruction) { return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape()); } +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes) { + for (const auto* instr : comp->instructions()) { + if (opcodes.count(instr->opcode())) { + return true; + } + for (const HloComputation* subcomp : instr->called_computations()) { + if (ContainsInstrWithOpcode(subcomp, opcodes)) { + return true; + } + } + } + return false; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c0826a6aee1f693484207a86ec258c6604d92318..215051f8834fc94eb9e32b508f34b13626ac9349 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -41,6 +43,12 @@ bool AllOperandsAreConstants(const HloInstruction& instruction); // Returns whether the instruction is a scalar constant. bool IsScalarConstant(const HloInstruction* instruction); +// Determines whether the given computation contains an instruction with one of +// the given opcodes. Checks both comp's instructions and the instructions of +// any computations nested within it. +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes); + // Returns an operand of an instruction with the given opcode. If there are // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 961930f0a888e90f86e4354fa1373a303af8ec2f..4aa8067752481ffab29e1a573ffa49d4aa046f1f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { @@ -22,7 +24,7 @@ HloReachabilityMap::HloReachabilityMap( : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { - indices_[hlo] = bit_vectors_.size(); + indices_[GetKey(hlo)] = bit_vectors_.size(); bit_vectors_.emplace_back(size_); } CHECK_EQ(size_, indices_.size()); // instructions should be unique @@ -71,4 +73,70 @@ bool HloReachabilityMap::IsConnected(const HloInstruction* a, return IsReachable(a, b) || IsReachable(b, a); } +std::unique_ptr HloReachabilityMap::Build( + const HloComputation* computation) { + const auto& all = computation->MakeInstructionPostOrder(); + auto result = absl::make_unique(all); + auto channel_dependency_map = computation->ComputeChannelDependencies(); + + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + + result->FastSetReachabilityToUnion(inputs, hlo); + } + return result; +} + +void HloReachabilityMap::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction) { + std::queue worklist; + worklist.push(instruction); + + std::vector inputs; + + while (!worklist.empty()) { + 
const HloInstruction* item = worklist.front(); + worklist.pop(); + + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), item->control_predecessors().begin(), + item->control_predecessors().end()); + + if (SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 5a5f01f8fd647c74217c80ce4a7633b8957e335f..7823b06a41b3052f6f50f7ffa358de5b23ba679f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -16,27 +16,30 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#include #include #include +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace xla { -class HloInstruction; - // A class for representing reachability between HloInstructions. // -// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix -// and it is up to the user of the class to set the adjacency matrix such that -// it represents reachability, i.e. such that it is transitive. 
That the graph -// be transitive is thus not an invariant of this class, but it is required for -// the name of the class and its methods to make sense. +// It has an adjacency matrix and it is up to the user of the class to set the +// adjacency matrix such that it represents reachability, i.e. such that it is +// transitive. That the graph be transitive is thus not an invariant of this +// class, but it is required for the name of the class and its methods to make +// sense. class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given @@ -44,6 +47,15 @@ class HloReachabilityMap { explicit HloReachabilityMap( absl::Span instructions); + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + static std::unique_ptr Build( + const HloComputation* computation); + // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true @@ -70,6 +82,10 @@ class HloReachabilityMap { // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. 
+ void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); + // Returns true if "b" is reachable from "a" // // Note that this function only correctly answers queries about reachability @@ -82,6 +98,11 @@ class HloReachabilityMap { // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + // Checks if an instruction is in the Reachability map. + bool IsPresent(const HloInstruction* a) const { + return indices_.contains(GetKey(a)); + } + private: // A bit-vector implementation specialized for this use case which provides a // fast bitwise OR operation not available in tensorflow::gtl::BitMap. @@ -143,18 +164,24 @@ class HloReachabilityMap { absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); + uint64 GetKey(const HloInstruction* instruction) const { + uint64 unique_id = absl::bit_cast(instruction->unique_id()); + uint64 module_id = + absl::bit_cast(instruction->parent()->parent()->unique_id()); + return (module_id << 32) | unique_id; + } // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { - return FindOrDie(indices_, instruction); + return FindOrDie(indices_, GetKey(instruction)); } // The number of instructions in the reachability map. const size_t size_; - // Dense assignment from HloInstruction* to number. These numbers index - // into the bit_vectors_ vector and into the bits within a BitVector. - absl::flat_hash_map indices_; + // Dense assignment from HloInstruction::unique_id to number. These numbers + // index into the bit_vectors_ vector and into the bits within a BitVector. + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. 
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index d9848cee0bfa904a90aea4626c3ee62c2cbb45b6..595176709806d54fc7c7c5ea301654717096b2d6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloVerifiedTestBase {}; +class HloReachabilityTest : public HloTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: @@ -48,7 +48,8 @@ TEST_F(HloReachabilityTest, Reachability) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.Build(); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); HloReachabilityMap reachability({a, b, c, d, e}); reachability.SetReachable(a, a); @@ -81,6 +82,130 @@ TEST_F(HloReachabilityTest, Reachability) { EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); } +TEST_F(HloReachabilityTest, NonTrivialReachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. 
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp)); + + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = HloReachabilityMap::Build(computation); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + 
EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + reachability->UpdateReachabilityThroughInstruction(exp); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + reachability->UpdateReachabilityThroughInstruction(negate); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + 
EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + +TEST_F(HloReachabilityTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = HloReachabilityMap::Build(computation); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 49e46ecd00ee4370f3e93746348373b79febed3d..48add75523f02005c70bc6baf69a6b7d5aa4f7ef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector; // before arbitrary elements. 
class InstructionList { public: - explicit InstructionList(const std::vector& order) { + explicit InstructionList(const HloInstructionSequence& order) { int64 position = 0; Item* last = nullptr; - for (const HloInstruction* inst : order) { + for (HloInstruction* inst : order.instructions()) { // Add a new item to the linked list. Item* item = new Item; item->next = nullptr; @@ -151,7 +151,7 @@ class InstructionList { // to be monotonically increasing through the list, and so is still useful // for quickly(-ish) determining the order of arbitrary instructions in // the list. - item->instruction = const_cast(inst); + item->instruction = inst; item->position = position; position++; @@ -927,7 +927,7 @@ Item* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, - const std::vector& order) const { + const HloInstructionSequence& order) const { InstructionList instruction_list(order); MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, instruction_list); @@ -971,8 +971,7 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list( - schedule->sequence(computation).instructions()); + InstructionList instruction_list(schedule->sequence(computation)); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1184,7 +1183,7 @@ StatusOr HloRematerialization::RematerializeComputation( sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { - const HloInstruction* instruction = item->instruction; + HloInstruction* instruction = item->instruction; sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1235,10 +1234,8 @@ StatusOr 
HloRematerialization::Run(HloModule* module) { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - module->schedule() - .sequence(node.computation()) - .instructions())); + ComputePeakMemory(node.computation(), module->schedule().sequence( + node.computation()))); } return Status::OK(); }, diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 70d83c04f07ca7fd0139f586869e8fe688f958f4..a07d348041b72bba45c6fd1f726f2a0065d01e53 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass { // peak memory is the maximum total size of all live HLO instruction values at // any program point. 'order' is the order in which the HLO instructions will // be emitted which is used to determine lifespans of HLO values. - StatusOr ComputePeakMemory( - const HloComputation* computation, - const std::vector& order) const; + StatusOr ComputePeakMemory(const HloComputation* computation, + const HloInstructionSequence& order) const; // Returns the peak memory usage of the called computations for the given // instruction. Zero is returned if the instruction calls no computations. diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f7e82fb1f88e856305f6f481a451d4cd64ba4acf..22c3c40a93a1ddcd36659483fcc79fede32dd2c3 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloVerifiedTestBase { +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -162,7 +162,7 @@ class HloRematerializationTest : public HloVerifiedTestBase { // Test rematerialization of a single computation produced by // MakeRematerializableComputation. TEST_F(HloRematerializationTest, SingleComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, module)); + /*memory_limit_bytes=*/14 * 1024, module.get())); EXPECT_TRUE(changed); // Root should not have changed. @@ -203,7 +203,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. 
TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, module)); + /*memory_limit_bytes=*/20 * 1024, module.get())); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -225,7 +225,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. TEST_F(HloRematerializationTest, RematerializeAroundWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // bit lower (17KB) to force rematerialization of the entry computation. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, module)); + /*memory_limit_bytes=*/17 * 1024, module.get())); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -261,7 +261,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while. Both the entry computation and while body computation should have // computations rematerialized. 
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, module)); + /*memory_limit_bytes=*/15 * 1024, module.get())); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -293,7 +293,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. TEST_F(HloRematerializationTest, RematerializeNestedComputations) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // ~12K so pick something slightly larger. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, module)); + /*memory_limit_bytes=*/13 * 1024, module.get())); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. 
@@ -346,7 +346,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + // // rng + tanh + exp - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( @@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get())); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -420,7 +420,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // The value %bcast is live across each call of Subcomputation (which requires // 8KB) though the value is not used in the calls. Rematerializing %bcast // across these calls reduces peak memory use from ~20KB down to ~16KB. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -533,7 +533,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // (ie %bcast is used indirectly by %negate), otherwise the %negate operand // aliases %add_2. const bool indirectly_used = GetParam(); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // rematerialization). 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 0778ff52174ef89c476950f2c830268a63888382..8f6eb974c5179b420c8f961393ca923e0a3b3530 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -46,8 +46,8 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -81,9 +81,8 @@ StatusOr HloSchedule::ToProto() const { return std::move(proto); } -void HloSchedule::set_sequence( - const HloComputation* computation, - absl::Span sequence) { +void HloSchedule::set_sequence(const HloComputation* computation, + absl::Span sequence) { set_sequence(computation, HloInstructionSequence(sequence)); } @@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer for instructions in the // computation. 
- absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } @@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. - absl::flat_hash_map> + absl::flat_hash_map> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's @@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule( // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; + std::queue worklist; - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { if (ids_in_schedule.count(instruction->unique_id()) == 0) { // This is a newly added instruction which is not in the schedule. if (instruction->operands().empty()) { @@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule( // Lambda which schedules all instructions on the worklist. auto schedule_worklist = [&]() { while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); + HloInstruction* instruction = worklist.front(); worklist.pop(); new_sequence.push_back(instruction); - std::vector* new_users = + std::vector* new_users = tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); if (new_users != nullptr) { // This just-scheduled instruction has users which are newly added to // the module. 
Update the number of unscheduled operands and push the // newly added instruction to the worklist if it is ready to // schedule. - for (const HloInstruction* new_user : *new_users) { + for (HloInstruction* new_user : *new_users) { unscheduled_operand_count.at(new_user)--; CHECK_GE(unscheduled_operand_count.at(new_user), 0); if (unscheduled_operand_count.at(new_user) == 0) { @@ -264,7 +263,10 @@ Status HloSchedule::Verify() const { } TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); + computation->instruction_count()) + << "Schedule for computation " << computation->name() << " has " + << instruction_position.size() << " instructions, expected " + << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(instruction_position.count(instruction) == 1) << "Instruction " << instruction->name() << " is not in schedule"; diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 0a714101ee587aa847fa674bbde5586287c51f33..486ddbf499de80c634bc497158cd79ca066cc866 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -35,14 +35,14 @@ class HloInstructionSequence { public: HloInstructionSequence() = default; explicit HloInstructionSequence( - absl::Span instructions) { - for (const HloInstruction* instruction : instructions) { + absl::Span instructions) { + for (HloInstruction* instruction : instructions) { push_back(instruction); } } // Adds the instruction to the end of the sequence. - void push_back(const HloInstruction* instruction) { + void push_back(HloInstruction* instruction) { instruction_sequence_.push_back(instruction); id_sequence_.push_back(instruction->unique_id()); } @@ -56,7 +56,7 @@ class HloInstructionSequence { int64 size() const { return instruction_sequence_.size(); } // Returns the sequence of HLO instructions. 
- const std::vector& instructions() const { + const std::vector& instructions() const { return instruction_sequence_; } @@ -65,7 +65,7 @@ class HloInstructionSequence { private: // The sequence as HloInstructions. - std::vector instruction_sequence_; + std::vector instruction_sequence_; // The sequence of HLO instructions, represented by their unique IDs. The // sequence is stored as both HloInstructions and unique IDs because the @@ -98,7 +98,7 @@ class HloSchedule { // Sets the sequence for the given computation to the given sequence. void set_sequence(const HloComputation* computation, - absl::Span sequence); + absl::Span sequence); void set_sequence(const HloComputation* computation, HloInstructionSequence sequence); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 1424569ac1f62e4b965876141f1eb40be4f15bea..0e56e6f760e35ddcb45c6f58771d78405a09acfe 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -56,10 +56,10 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); - const std::vector& entry_schedule = + const auto& entry_schedule = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(entry_schedule.size(), 6); @@ -90,7 +90,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -139,7 +139,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), 
[](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -183,7 +183,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 45c684d66752862eec301b8943d350804f070309..c1073911ea9dc3811c195e27bcbae9b00929ad17 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -66,7 +66,7 @@ class HloSubcomputationUnificationTest : public HloTestBase { }; TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -103,7 +103,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { } TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto 
callee1 = @@ -143,7 +143,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { // Do not unify subcomputations with different parameter shapes. TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -184,7 +184,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { // Regression test for b/31466798. Checks that entry_computation is still valid // after unification. TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); for (int i = 0; i < 2; ++i) { HloComputation::Builder builder("pow"); auto x = diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f..1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloVerifiedTestBase { +class HloTfGraphBuilderTest : public HloTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index b6670d409b92e8be42f5cdb40fba8d662ae83958..1f01b0bb365450a933da9cc443db5223c06903f0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -166,9 +166,6 @@ class HloValue : public BufferValue { // Whether this value is live out of the HLO module. bool live_out_of_module_ = false; - - // Whether this value is live out of its computation. - bool live_out_of_computation_ = false; }; std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 136824a33565d65663f1e484713c5180a762b25b..60d8a511b5743d4f342a2cc3a7c91c71acdbeaf8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -755,6 +756,12 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); } +Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { + return CheckShape(get_size, + ShapeInference::InferGetDimensionSizeShape( + get_size->operand(0)->shape(), get_size->dimension())); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -1331,6 +1338,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleCrossReplicaSum(HloInstruction* crs) override { + if (crs->all_reduce_id().has_value()) { + TF_RET_CHECK(crs->all_reduce_id().value() > 0) + << "All reduce id must be greater than 0 for " + << crs->ToShortString(); + } + return Status::OK(); + } + Status Preprocess(HloInstruction* instruction) override { auto previous = instructions_by_name_.find(instruction->name()); TF_RET_CHECK(previous == instructions_by_name_.end()) @@ -1410,6 +1426,8 @@ StatusOr HloVerifier::Run(HloModule* module) { return target_metadata_->ShapeSize(shape); })); + TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 83b6244d1be0e1eec66daabfcfd1be5a3c0131ac..9fbfd6a21c1f1148801000169046fbcbb37934fe 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ 
-94,6 +94,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleGather(HloInstruction* gather) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; + Status HandleGetDimensionSize(HloInstruction* get_size) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index afe01e5487c3225815e01343d86e9fe894c2cde8..4bc557e4e62e7df4e25fda86fe417e84129b464c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -35,7 +35,11 @@ namespace { using ::testing::HasSubstr; -// This class cannot be converted to use HloVerifiedTestBase. It explicitly +std::unique_ptr CreateUnverifiedModule() { + return absl::make_unique("module", HloModuleConfig()); +} + +// This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. 
class HloVerifierTest : public HloTestBase { public: @@ -66,7 +70,7 @@ TEST_F(HloVerifierTest, NullInstructionParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -85,7 +89,7 @@ TEST_F(HloVerifierTest, NullComputationParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -104,7 +108,7 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloComputation::Builder emb_builder(TestName()); @@ -138,7 +142,7 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { builder.AddInstruction( HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Run the verifier twice. 
It should fail both times, because it shouldn't @@ -303,7 +307,7 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto status = verifier().Run(module.get()).status(); @@ -327,7 +331,7 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index e103222b55faccf2d0286dce33c0f1ce5df01feb..90904ac00110457bcc3b8974816a7080c4ab89fc 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -90,20 +90,29 @@ string HumanReadableProfileBuilder::ToString() const { op.optimal_seconds < 0 ? "" : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), - op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), - op.transcendental_count <= 0 - ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + op.flop_count > 0 && nsecs > 0 + ? HumanReadableNumFlops(op.flop_count, nsecs) + : "", + op.transcendental_count > 0 && nsecs > 0 + ? 
HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + : "", bytes_per_sec, bytes_per_cycle, op.name); }; - float optimal_seconds_sum = 0.0; + double optimal_seconds_sum = 0; int64 total_flops = 0.; int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { if (op.optimal_seconds > 0) { - optimal_seconds_sum += op.optimal_seconds; + // An op can run faster than the estimated optimum. For example, we might + // estimate a fusion's speed by looking at the size of its operands and + // result, but perhaps the fusion doesn't read the entirety of all of its + // inputs. For the purposes of summing the instructions' optimal speeds, + // we treat the "optimum" as the smallest of either the estimated optimum + // and the actual speed. + optimal_seconds_sum += + std::min(double{op.optimal_seconds}, CyclesToSeconds(op.cycles)); } total_flops += std::max(op.flop_count, int64{0}); total_transcendentals += std::max(op.transcendental_count, int64{0}); @@ -114,7 +123,7 @@ string HumanReadableProfileBuilder::ToString() const { print_op({is_entry_computation_ ? "[total] [entry]" : "[total]", "[total]", /*category=*/"", total_cycles_, total_flops, total_transcendentals, - total_bytes, optimal_seconds_sum}, + total_bytes, static_cast(optimal_seconds_sum)}, /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. @@ -155,8 +164,10 @@ string HumanReadableProfileBuilder::ToString() const { entry.text = op.name; entry.short_text = op.short_name; entry.category_text = op.category; - entry.metric = - CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6; + // Ignore ops that run faster than the estimated optimal here, as we do + // when calculating optimal_seconds_sum. 
+ entry.metric = std::max( + 0., CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6); total_discrepancy_in_microseconds += entry.metric; table.AddEntry(std::move(entry)); } diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d5225b8012b68f851b2bfec219d736ba0d..cf6cf897fe11eda01ba6b22119bba34ac2bef8fe 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -18,19 +18,20 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { +class ImplicitBroadcastRemoverTest : public HloTestBase { protected: ImplicitBroadcastRemover remover_; }; TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -41,15 +42,16 @@ TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_FALSE(remover_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(remover_.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = 
HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -60,13 +62,13 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); @@ -76,6 +78,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -86,9 +89,9 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Subtract(op::Parameter(), @@ -98,6 +101,7 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, 
{1, 4, 1}); @@ -108,9 +112,9 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, @@ -120,6 +124,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { } TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); @@ -132,9 +137,9 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), @@ -147,6 +152,7 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { TEST_F(ImplicitBroadcastRemoverTest, TernaryScalarAndDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -159,9 +165,9 @@ TEST_F(ImplicitBroadcastRemoverTest, builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = 
module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 2d03aebc1aca4c55cca588072233b7a18e70a306..98246d5403e4aebc2f4d81e52145706355ddd9a9 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { namespace { -class IndexedArrayAnalysisTest : public HloVerifiedTestBase { +class IndexedArrayAnalysisTest : public HloTestBase { protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -61,12 +61,12 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { const string& root_expression, bool print_constants) { IndexedArrayAnalysis indexed_tensor_analysis; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN( - IndexedArrayAnalysis::Array* const array_result, - indexed_tensor_analysis.GetArrayFor( - module().entry_computation()->root_instruction())); + TF_ASSERT_OK_AND_ASSIGN(IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + m->entry_computation()->root_instruction())); string string_result = CanonicalizeWhitespace( 
indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; @@ -481,8 +481,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { - { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, - { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } } }) (reshape %indices to s32[]) 0->[]) )"; @@ -512,8 +512,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,6] s32[2,1,1,6] { - { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, - { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + { /*i0=0*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } }, + { /*i0=1*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } } }) (reshape %indices to s32[5]) 0->[2]) )"; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 69a4c160ee5c4539272c3085338dc6de1b9347ff..7f2d7e7cffc6debaaf9b64fffc5a8a7037ecdaa3 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -26,7 +26,9 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -153,6 +155,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGetDimensionSize: return true; } @@ -437,8 +440,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { } // namespace std::unique_ptr InstructionFusion::GetFusionQueue( - HloComputation* computation, - const std::function& skip_producer) { + HloComputation* computation) { return absl::make_unique(computation); } @@ -451,14 +453,11 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); - auto fusion_queue = - GetFusionQueue(computation_, [&](HloInstruction* producer) { - return do_not_duplicate.count(producer) > 0; - }); + auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -489,9 +488,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { HloInstruction* fusion_instruction; // Try "regular" fusion if the operand may be duplicated. Otherwise, // perform multi-output fusion, unless this creates a cycle. - // TODO(tjoerg): Consider making multi-output fusion the default. 
- if (ShouldFuse(instruction, i) && - do_not_duplicate.count(operand) == 0) { + if (do_not_duplicate.count(operand) == 0 && + ShouldFuse(instruction, i)) { fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && @@ -565,15 +563,19 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return absl::c_any_of( - consumer->operands(), [&](const HloInstruction* consumer_operand) { - // The fusion algorithm traverses the HLO graph in reverse post order. - // Thus `cosumers` is visited before its operands (including - // `producer`). Therefore, consumer operands cannot have been fused yet. - // It is thus safe to use the pre-computed reachability map. - return consumer_operand != producer && - reachability_->IsReachable(producer, consumer_operand); - }); + auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { + // A consumer operand may have been multii-output fused into a parallel + // consumer and thus be missing from the oridinal reachability map. 
+ if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { + reachability_ = HloReachabilityMap::Build(consumer->parent()); + } + return reachability_->IsReachable(a, b); + }; + return absl::c_any_of(consumer->operands(), + [&](const HloInstruction* consumer_operand) { + return consumer_operand != producer && + is_reachable(producer, consumer_operand); + }); } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f14c6675208c72112aea0179c238b58709d625b5..198bd7fce5f392e5e895b959523d4fe9cf208ba2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -54,8 +55,7 @@ class InstructionFusion : public HloModulePass { // fused. The default implementation processes consumers in reverse post // order. virtual std::unique_ptr GetFusionQueue( - HloComputation* computation, - const std::function& skip_producer); + HloComputation* computation); // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. @@ -111,6 +111,10 @@ class InstructionFusion : public HloModulePass { return is_expensive_(instruction); } + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Current HloComputation instance the loop fuser is traversing. 
HloComputation* computation_; HloModule* module_; @@ -145,10 +149,6 @@ class InstructionFusion : public HloModulePass { // duplicated. std::function is_expensive_; - // Whether multi-output fusion would introduce a cycle into the HLO graph. - bool MultiOutputFusionCreatesCycle(HloInstruction* producer, - HloInstruction* consumer); - // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index da1ad90959dc0ab1a840b3390281ce9d4999651e..6b483126499fe1e635a7d13cf597ec5d089c5b24 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -133,7 +133,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -149,7 +149,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); 
EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( @@ -172,7 +172,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_FALSE( @@ -361,7 +361,7 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary2, computation->root_instruction()); EXPECT_TRUE( @@ -385,7 +385,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index a06d6113e84630df14ff68280c248cccb9afaf06..7635fbfed6f6a51fc9d203251d9bebf43cc63fd9 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -37,7 +37,7 @@ namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, /*hlo_profile_index_map=*/nullptr), diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h 
b/tensorflow/compiler/xla/service/interpreter/executable.h index 3b1ebce0c75457d65e6834c809fe488a9c4a159a..bda13d376360306c81230e41b01cefc6caff230d 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -42,7 +42,7 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module, + InterpreterExecutable(std::unique_ptr hlo_module, std::unique_ptr evaluator); ~InterpreterExecutable() override; diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index c9b40d3c6195f80a19272a0d98890049d02315b9..b0fc1af8b89d7327a00f77f471e90d143a92de7c 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -110,3 +110,5 @@ REGISTER_MODULE_INITIALIZER( // open-source project, so this will be a no-op there. 
REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, multi_platform_manager); +REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, + interpreter_platform); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 6b03394669858ef0ffdbdd1a2bad90e9df9fbcd9..a90411922205c0006159ff99f35a70138b1bee4f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2092,6 +2092,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kGetDimensionSize: return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 47bfca2fd6e1527b73e396151d3764867ac03697..61d8a0a4e6aa39e2e921acae1c65df1b3c329e46 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,20 +49,19 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloVerifiedTestBase { +class LayoutAssignmentTest : public HloTestBase { protected: - void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout, + void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, /*channel_constraints=*/channel_constraints); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m).status()); } - std::vector LayoutOf(HloModule* module, absl::string_view name) { + std::vector LayoutOf(HloModule* m, absl::string_view name) { auto minor_to_major = - FindInstruction(module, name)->shape().layout().minor_to_major(); + FindInstruction(m, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } @@ -91,7 +90,7 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { TEST_F(LayoutAssignmentTest, ComputationLayout) { // Verify the layouts of the root and parameter instructions of a computation // match the ComputationLayout for two different layouts. 
- std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); @@ -101,8 +100,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout layout = LayoutUtil::MakeLayout(minor_to_major); Shape shape(ashape); @@ -113,7 +112,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -131,8 +130,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); Shape col_major_shape(ashape); @@ -149,7 +148,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { 
*computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -160,7 +159,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { // Verify that the layout of the fused parameters in a fusion instruction // match that of the fusion operands. Other fused instructions should have no // layout. - std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); auto constant_literal1 = LiteralUtil::CreateR2WithLayout( @@ -180,8 +179,8 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); @@ -194,7 +193,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -229,13 +228,13 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { auto negate = builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kNegate, get_element0)); - auto module = 
CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -267,17 +266,17 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -302,11 +301,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { auto nested_tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, inner_tuple})); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = nested_tuple->shape(); *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); @@ -316,7 
+315,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -329,12 +328,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // %tuple.1 = Tuple(%copy) layout=({0,1}) // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1})) // - EXPECT_TRUE( - AlgebraicSimplifier(/*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return false; }) - .Run(module) - .ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_is_layout_sensitive(true); + EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. 
EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), @@ -361,9 +359,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -374,7 +371,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -403,8 +400,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { HloInstruction::CreateTranspose(bshape, log, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + auto computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -415,7 +412,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -439,9 +436,9 @@ TEST_F(LayoutAssignmentTest, 
BroadcastAndTranspose) { HloInstruction::CreateBroadcast(bshape, param, {1, 2})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); Shape input_shape_with_layout(ashape); Shape output_shape_with_layout(cshape); @@ -454,7 +451,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -488,9 +485,8 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tuple)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tuple)); ComputationLayout computation_layout(computation->ComputeProgramShape()); Shape param_shape_with_layout(f32_4); @@ -507,7 +503,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -558,9 +554,8 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { 
HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); auto reshape = builder.AddInstruction( HloInstruction::CreateReshape(cshape, concatenate)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(reshape)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(reshape)); Shape param0_shape_with_layout(ashape); Shape param1_shape_with_layout(ashape); @@ -573,7 +568,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -593,11 +588,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -611,11 +606,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 
5}), broadcast, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -681,12 +676,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - ParseAndVerifyModule(module_str); - + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); @@ -721,9 +716,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -735,19 +731,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); - 
EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -757,7 +753,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); @@ -784,7 +780,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { true_builder.AddInstruction(HloInstruction::CreateTuple({add})); } HloComputation* true_computation = - module->AddEmbeddedComputation(true_builder.Build()); + m->AddEmbeddedComputation(true_builder.Build()); auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); { @@ -800,14 +796,14 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = - module->AddEmbeddedComputation(false_builder.Build()); + m->AddEmbeddedComputation(false_builder.Build()); builder.AddInstruction(HloInstruction::CreateConditional( result_tshape, pred, tuple, true_computation, tuple, false_computation)); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = 
m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -828,13 +824,13 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module).status(); + Status error_status = layout_assignment.Run(m.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -861,9 +857,10 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -873,12 +870,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(&module(), &computation_layout, &channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); - 
EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } @@ -897,17 +894,17 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 ar.0 = f32[2,2] cross-replica-sum(gte), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -917,12 +914,12 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); - const HloInstruction* root = module->entry_computation()->root_instruction(); + 
EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); } @@ -938,11 +935,12 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -966,11 +964,12 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -997,11 +996,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1028,11 +1028,12 @@ TEST_F(LayoutAssignmentTest, } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto 
compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1050,11 +1051,12 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1107,20 +1109,21 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); // Sanity check to verify that there's a layout mismatch. - EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Make sure that layout assignment did not magically eliminate the mismatch, // in which case the test didn't prove anything. 
- EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); } TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { @@ -1136,32 +1139,32 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { // and result layout should match that of the computation. { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ASSERT_THAT(root, op::CustomCall(op::Parameter())); ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); } { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); - AssignLayouts(module.get(), 
&computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ASSERT_THAT(root, op::CustomCall(op::Parameter())); ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); @@ -1179,24 +1182,24 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3 } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_parameter_layout(1) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // The custom call should be partially encapsulated in kCopy instructions // because of the layout mismatches. 
- ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); @@ -1211,18 +1214,18 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall())); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); } @@ -1238,25 +1241,25 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); 
*computation_layout.mutable_parameter_layout(1) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall(op::Tuple()))); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); } @@ -1273,36 +1276,34 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, // Try with a couple different layouts. In each case the custom calls operand // and result layout should match that of the computation. 
TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - ExpectTupleLayoutIs(module->result_shape(), {{1, 0}, {1, 0}}); + ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}}); - const HloInstruction* custom_call = - FindInstruction(module.get(), "custom-call"); + const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call"); ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); } Status AssignLayoutsToComputation( - HloModule* module, - ChannelLayoutConstraints* channel_constraints = nullptr) { - if (!module->entry_computation_layout().result_layout().LayoutIsSet()) { - module->mutable_entry_computation_layout() + HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) { + if (!m->entry_computation_layout().result_layout().LayoutIsSet()) { + m->mutable_entry_computation_layout() ->mutable_result_layout() ->SetToDefaultLayout(); } LayoutAssignment layout_assignment( - module->mutable_entry_computation_layout(), + m->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout, channel_constraints); - return layout_assignment.Run(module).status(); + return layout_assignment.Run(m).status(); } TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { @@ -1325,16 +1326,16 @@ TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { auto add = b.AddInstruction( 
HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1)); b.AddInstruction(HloInstruction::CreateTuple({add, transpose})); - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(b.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(b.Build()); Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0}); Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1}); - *module->mutable_entry_computation_layout()->mutable_result_layout() = + *m->mutable_entry_computation_layout()->mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor})); const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0}); - ForceParameterLayout(module.get(), 0, r2_dim0major); - ForceParameterLayout(module.get(), 1, r2_dim0major); - TF_ASSERT_OK(AssignLayoutsToComputation(module.get())); + ForceParameterLayout(m.get(), 0, r2_dim0major); + ForceParameterLayout(m.get(), 1, r2_dim0major); + TF_ASSERT_OK(AssignLayoutsToComputation(m.get())); EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0)); EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(), diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 850501a4b5c521f5a5cc29658a04ae4bb638e14f..728a66b388f0f9af480ff88b5e96990a26e36af5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", ], @@ -197,14 +198,17 @@ cc_library( hdrs = ["sort_util.h"], deps = [ ":ir_array", + ":kernel_support_library", ":llvm_loop", ":llvm_util", ":loop_emitter", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index f4b05f29c38529b3cce81b4c8ee6fae5c00cafcc..1540a40ef820f483c27b3d0d81d24ebb265847b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -108,6 +108,14 @@ class IrArray { Index(absl::Span multidim, llvm::Value* linear, const Shape& shape); + // Returns an index that adds `addend` to the given `dim` of the object. + Index AddOffsetToDim(llvm::Value* addend, int64 dim, + llvm::IRBuilder<>* b) const { + IrArray::Index index = *this; + index[dim] = b->CreateAdd(index[dim], addend); + return index; + } + const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index e5fbdbd51b8a9aa14decadedd1eeb3bdbf831738..c26711e526c9b89cdedcb6aed9f93d41dd25dc83 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -52,6 +52,29 @@ Shape MergeDimensions(absl::Span segs, const Shape& shape) { return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), dimensions); } + +// Given an index for a shape, return the equivalent new index if the shape is +// reshaped to another shape. 
+IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape, + const Shape& reshaped_shape, + llvm::IRBuilder<>* b) { + auto bounds = shape.dimensions(); + auto minor_to_major = shape.layout().minor_to_major(); + llvm::Value* linear_index = index.GetConstantWithIndexType(0); + int64 multiplier = 1; + for (int i = 0; i < index.size(); ++i) { + int64 dim = minor_to_major[i]; + llvm::Value* addend = b->CreateMul( + index[dim], index.GetConstantWithIndexType(multiplier), "linearizing", + /*HasNUW=*/true, /*HasNSW=*/true); + linear_index = b->CreateAdd(linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= bounds[dim]; + } + + return IrArray::Index(linear_index, reshaped_shape, b); +} + } // namespace absl::optional > FindTranspose021(const Shape& a, @@ -60,28 +83,30 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } - std::vector perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + std::vector permutation(a.dimensions().size()); + absl::Span minor_to_major_a = LayoutUtil::MinorToMajor(a); + std::vector major_to_minor_a(minor_to_major_a.rbegin(), + minor_to_major_a.rend()); + absl::Span minor_to_major_b = LayoutUtil::MinorToMajor(b); + std::vector major_to_minor_b(minor_to_major_b.rbegin(), + minor_to_major_b.rend()); + for (size_t i = 0; i < permutation.size(); ++i) { + permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]); } - auto segs = ConsecutiveSegments(perm); - if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) { - Shape norm_a = + + std::vector segments = ConsecutiveSegments(permutation); + if ((3 == segments.size() && 0 == 
permutation[0]) || 2 == segments.size()) { + Shape descending_layout_shape = ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape reduced_a = MergeDimensions(segs, norm_a); - auto reduced_a_dims = reduced_a.dimensions(); + Shape normalized_shape = MergeDimensions(segments, descending_layout_shape); + absl::Span normalized_dims = + AsInt64Slice(normalized_shape.dimensions()); std::vector dims_021; - if (2 == segs.size()) { + if (2 == segments.size()) { // The logical component-0 is of size one. - dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]}; + dims_021 = {1, normalized_dims[1], normalized_dims[0]}; } else { - dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]}; + dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]}; } return dims_021; @@ -90,27 +115,117 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b) { - auto bounds = reduced_output_shape.dimensions(); - auto minor_to_major = reduced_output_shape.layout().minor_to_major(); - llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0); - int64 multiplier = 1; - for (int i = 0; i < reduced_output_index.size(); ++i) { - int64 dim = minor_to_major[i]; - llvm::Value* addend = - b->CreateMul(reduced_output_index[dim], - reduced_output_index.GetConstantWithIndexType(multiplier), - "linearizing", - /*HasNUW=*/true, /*HasNSW=*/true); - linear_index = b->CreateAdd(linear_index, addend, "", - /*HasNUW=*/true, /*HasNSW=*/true); - multiplier *= bounds[dim]; +KernelMappingScheme::KernelMappingScheme( + absl::Span dims_in_elems, int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, int64 num_threads_y, + int64 num_threads_x, llvm::IRBuilder<>* b) + : b_(b), + dims_in_elems_(dims_in_elems), + tile_sizes_{1, tile_size_y, 
tile_size_x}, + num_threads_x_(num_threads_x), + num_threads_y_(num_threads_y) { + DCHECK_EQ(dims_in_elems_.size(), 3); + DCHECK_EQ(req_block_sizes.size(), 3); + + DCHECK_EQ(tile_size_y % num_threads_y_, 0); + DCHECK_EQ(tile_size_x % num_threads_x_, 0); + + dims_in_tiles_ = ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_); + block_sizes_.reserve(req_block_sizes.size()); + absl::c_transform(req_block_sizes, dims_in_tiles_, + std::back_inserter(block_sizes_), + [](const int64 requested_size, const int64 max_size) { + return std::min(requested_size, max_size); + }); + dims_in_blocks_ = ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_); + + VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]"; + VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]"; + VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",") + << "]"; +} + +IrArray::Index KernelMappingScheme::GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape) { + DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size()); + Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + unnormalized_shape.element_type(), GetDimensionsInElements()); + return GetReshapedIndex(normalized_shape_index, output_shape, + unnormalized_shape, b_); +} + +IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(), + llvm::cast(block_id)); + llvm::Value* linear_block_id = + b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); + return IrArray::Index(linear_block_id, + ShapeUtil::MakeShapeWithDescendingLayout( + PRED /*arbitrary*/, dims_in_blocks_), + b_); +} + +IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( + const IrArray::Index& block_index) { + IrArray::Index tile_index = block_index; + 
for (int i = 0; i < block_sizes_.size(); ++i) { + tile_index[i] = b_->CreateMul( + block_index[i], + llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), + "block_origin." + std::to_string(i)); + } + return tile_index; +} + +IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( + const IrArray::Index& tile_index) { + IrArray::Index elem_index = tile_index; + for (int i = DimY; i < DimTot; ++i) { + elem_index[i] = + b_->CreateMul(tile_index[i], + llvm::ConstantInt::get(tile_index[i]->getType(), + GetTileSizeForDimension(i)), + "tile_origin." + std::to_string(i)); } + return elem_index; +} + +llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name) { + // If shared memory transpose is needed, we use square tiles. + CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY()); + + // For Nvidia GPUs, the warp size is 32 threads and the shared memory bank is + // organized into 32-way. We usually use the warp size or a multiple of + // the warp size as the size for tiling. This may cause all elements in the + // same column of a tile to use the same memory bank and therefore shared memory + // bank conflicts. Adding 1 to the minor dimension of the shared memory buffer + // can reduce such shared memory bank conflicts. + llvm::Type* buffer_type = llvm::ArrayType::get( + llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1), + GetTileSizeForDimension(DimY)); + return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(), + buffer_type, buffer_name); +} - return IrArray::Index(linear_index, unreduced_output_shape, b); +std::tuple +KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { + // Calculate (y, x) coordinate of the thread in the 2D view of thread block + // defined by (num_thread_y, num_thread_x) from thread_id.
+ llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm::Value* thread_id_int = + b_->CreateIntCast(thread_id_raw, index_ty, + /*isSigned=*/true, "thread.id.x"); + llvm::Value* num_thread_x = + llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + return std::make_tuple(y, x); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 5ea05b3188a1c0881e4c0c41625d530aff1b1205..06002d57b0d7daa07f903feebe67a60a083c0e7c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -28,23 +28,160 @@ namespace llvm_ir { // If a shape can be viewed as three logical components 0-1-2 in the order of // major to minor, a 0-2-1-transpose changes the order of such logical // components to 0-2-1. We call the shape being transposed the input shape and -// the transposed shape the output shape. The logical view of the input and -// output shapes for the transpose are called the 0-1-2 shape or reduced input -// shape and the 0-2-1 shape or the reduced output shape respectively. The -// original input and output shapes are called the unreduced input and output -// shapes. - +// the transposed shape the output shape. The logical view of the input/output +// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized +// shapes. The original input/output shapes are called unnormalized shapes. +// // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the -// reduced shape of `b` or the 0-2-1 shape. +// normalized shape of `b` or the 0-2-1 shape. 
absl::optional > FindTranspose021(const Shape& a, const Shape& b); -// Return the unreduced output index corresponding to the given reduced output -// index. -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b); +// A tile is a spatial subdivision of a tensor. We group tensor elements into +// tiles so that we can launch kernels to process the tensor elements in blocks +// of tiles. +// +// A kernel mapping scheme describes a method to partition the tensors accessed +// by an unnested HLO instruction into tiles and blocks of tiles, and the +// associated information to use hardware threads to process the tensor elements +// in blocks of tiles. +// +// Currently, there are two main use cases for a tiling scheme. First, we +// implement kernels with 0-2-1 memory transpose using shared memory to improve +// memory access pattern. Second, we implement reduction to contiguous +// dimensions in layout, with or without memory transpose, to achieve better +// memory access pattern as well as to reduce the number of executed +// expensive instructions, such as thread synchronization related instructions +// and atomic operations. For both use cases, we can apply a normalization to +// the original tensors, to collapse contiguous dimensions for the same purpose +// and produce normalized three dimensional tensors. For this reason, the tiling +// scheme class only needs to handle normalized three dimensional tensors and +// two dimensional tiles. +// +// The current implementation of the class is somewhat NVIDIA GPU oriented. This +// situation can be improved when there is a need though. The idea of 0-2-1 +// transpose using shared memory can be found in the following CUDA algorithm in +// TensorFlow: https://goo.gl/MStRV6.
+// +// We use a thread block to process a tile because we want to use the HW thread +// block synchronization primitives to synchronize the processing of all the +// elements in the same tile. A thread block can be viewed as a two dimensional +// array of threads, described by the number of threads for the Y and X +// dimensions. A thread block (num_threads_y, num_threads_x) processes a tile of +// (tile_size_y, tile_size_x) as follows: each thread in the thread block +// processes one element in the tile so that all the threads in the thread block +// together process a subdivision of the tile that has the same dimension as the +// thread block array. Then the thread block moves on to process the next +// subdivision of the tile until the whole tile is processed. Therefore, each +// thread in the thread block processes +// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile. +// +// There are situations where we want a thread block to process multiple +// tiles. We can't group those tiles into a bigger tile because we limit a tile +// to a two dimensional spatial subdivision of a tensor. For example, when we +// use tiling to implement reduction with transpose, we want the partial sum +// produced by each thread to accumulate values for more elements before using +// shfl_down and atomic_add instructions for further reduction, to amortize the +// cost of such expensive instructions. The concept of tile block is introduced +// for this purpose. A tile block is a three dimensional array of tiles, of +// which some dimensions may be degenerated to only one tile. +class KernelMappingScheme { + public: + enum { DimZ = 0, DimY, DimX, DimTot }; + + public: + // dims_in_elems: the normalized tensor dimensions. + // req_block_sizes: the requested block size in number of tiles for each + // dimension. The actual block size is set to min(req_block_size, + // dims_in_number_of_tiles).
+ explicit KernelMappingScheme(absl::Span dims_in_elems, + int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); + + absl::Span GetDimensionsInElements() const { + return dims_in_elems_; + } + absl::Span GetDimensionsInTiles() const { + return dims_in_tiles_; + } + absl::Span GetDimensionsInBlocks() const { + return dims_in_blocks_; + } + + int64 GetNumberOfTilesInTotal() const { + return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies()); + } + int64 GetNumberOfTilesInOneBlock() const { + return absl::c_accumulate(block_sizes_, 1, std::multiplies()); + } + + int64 GetNumberOfBlocks() const { + return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); + } + + int64 GetTileSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return tile_sizes_[d]; + } + int64 GetTileSizeForDimensionX() const { + return GetTileSizeForDimension(DimX); + } + int64 GetTileSizeForDimensionY() const { + return GetTileSizeForDimension(DimY); + } + + absl::Span GetBlockSizes() const { return block_sizes_; } + + int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } + int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } + + int64 GetThreadsPerTile() const { + return GetNumberOfThreadsForDimensionX() * + GetNumberOfThreadsForDimensionY(); + } + + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); + // Returns the index for the first tile in the block with the given block + // index. + IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index); + // Returns the index for the first element in the tile with the given tile + // index. 
+ IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index); + + std::tuple EmitThreadYXCoordinate( + llvm::Type* index_ty); + + IrArray::Index GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape); + + llvm::GlobalVariable* GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name); + + private: + llvm::IRBuilder<>* b_; + // The number of elements in each dimension. + absl::Span dims_in_elems_; + + // The number of elements for each dimension of a tile. + std::vector tile_sizes_; + // The number of tiles in each dimension. It is computed from dims_in_elems_ + // and tile_sizes_. + std::vector dims_in_tiles_; + + // The number of tiles for each dimension of a tile block. + std::vector block_sizes_; + // The number of tile blocks in each dimension. It is computed from + // dims_in_tiles_ and block_sizes_. + std::vector dims_in_blocks_; + + // Number of threads used to process elements in the X direction of a tile. + int64 num_threads_x_; + // Number of threads used to process elements in the Y direction of a tile. + int64 num_threads_y_; +}; // A class to represent information for tiled parameters to support IR emission // for 021 transpose. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 2e5aebb74c29b809ae5c323b1912043d9f160d67..df78726166eea953b57e72a5a5fc81ee246aca34 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" @@ -83,10 +84,9 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b) { +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); @@ -260,6 +260,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, /*AddNull=*/false); } +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name) { + const int kNVPTXSharedMemoryAddrSpace = 3; + return new llvm::GlobalVariable( + *module, tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); +} + llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index f59baff263fe7184c6b0821c9dbd9eee205586a6..c604c7c870adf734a29017e6accbd159317a9548 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -101,10 +102,9 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b); +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise @@ -155,6 +155,11 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module); +// Allocates a tile of shared memory. +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name); + // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. The insert point // of the builder is set to the same place after calling this function diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 05ba4a40da413f0e774214e55ef69d023afc48e2..e22c2173c271fc9571be1ddb0759d2b31562dc98 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -18,7 +18,9 @@ limitations under the License. 
#include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -28,10 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -39,147 +43,365 @@ namespace xla { namespace llvm_ir { namespace { -// Adds the inner comparison loop where we compare elements pointed to by -// 'keys_index' and 'compare_keys_index'. 
-void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, - const IrArray::Index& compare_keys_index, - const IrArray& keys_array, - const std::vector& values_arrays, - llvm::IRBuilder<>* b) { - // if (is_smaller_index && - // compare_keys[dimension_to_sort] < dimension_to_sort_bound) - llvm::Value* is_smaller_index = b->CreateICmpSLT( - keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - auto if_data = EmitIfThenElse( - b->CreateAnd(is_smaller_index, - b->CreateICmpSLT(compare_keys_index[dimension_to_sort], - keys_index.GetConstantWithIndexType( - dimension_to_sort_bound))), - "smaller_comparison_index", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_data.true_block, b); - auto key1 = keys_array.EmitReadArrayElement(keys_index, b); - auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); - auto compare_key1 = key1; - auto compare_key2 = key2; - auto key_type = keys_array.GetShape().element_type(); - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - // We would like a total order of floating point numbers so that the sort - // has a predictable behavior in the presence of NaNs. Rather than using - // floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the obvious - // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - // and end of the ordering. 
- auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; + +// Adds the inner comparison loop body where we compare elements. +void EmitCompareLoopBody( + int64 iteration_bound, PrimitiveType key_type, int64 num_values, + int64 iota_values_parameter_index, llvm::Value* element_pair_index, + int64 xor_mask, llvm::Type* index_type, + std::function read_element, + std::function + write_element, + llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { + auto index_typed_constant = [&](int64 value) { + return llvm::ConstantInt::get(index_type, value); + }; + // The 'xor_mask' determines which elements are compared against each other. + // Index 'current_keys_index' will be compared with 'current_keys_index' xor + // 'xor_mask'. This means that we will always compare a block of consecutive + // elements against elements from the adjacent block of the same size. When + // 'xor_mask' is a power of 2, it immediately identifies the size of such a + // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In + // that case, we essentially flip the last 'k' - 1 bits when computing the + // position of the element to compare to, so the block size is 2^(k - 1). + int64 block_size = xor_mask; + // Check if it is a value 2^k - 1. + if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) { + block_size = (xor_mask + 1) / 2; } - auto comparison = - b->CreateICmp(is_signed_comparison ? 
llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - // If key2 < key1 - auto if_smaller_data = - EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_smaller_data.true_block, b); - // Swap key1 with key2. - keys_array.EmitWriteArrayElement(keys_index, key2, b); - keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); - for (const auto& values_array : values_arrays) { - // Also swap the values. - auto value1 = values_array.EmitReadArrayElement(keys_index, b); - auto value2 = values_array.EmitReadArrayElement(compare_keys_index, b); - values_array.EmitWriteArrayElement(keys_index, value2, b); - values_array.EmitWriteArrayElement(compare_keys_index, value1, b); + auto current_keys_index = element_pair_index; + if (block_size == 1) { + // If the block size is 1, we take every second element and compare it to + // the next one. + current_keys_index = + b->CreateMul(current_keys_index, index_typed_constant(2)); + } else if (block_size * 2 < iteration_bound) { + // current_keys_index iterates through the 'left' elements of the element + // pairs to be compared. We first need to compute the comparison block to + // which the element belongs. The block id of that block is index / + // block_size. + auto block_id = + b->CreateUDiv(current_keys_index, index_typed_constant(block_size)); + // The index of the 'left' element within its block is simply the remainder + // when dividing by 'block_size'. + auto index_within_block = + b->CreateURem(current_keys_index, index_typed_constant(block_size)); + // The first element of the 'left' block of elements that is compared + // against elements from the adjacent 'right' block of elements is + // 'block_id' * (2 * 'block_size'). 
+ auto first_element_in_block = + b->CreateMul(block_id, index_typed_constant(2 * block_size)); + current_keys_index = + b->CreateAdd(first_element_in_block, index_within_block); } + auto compare_keys_index = + b->CreateXor(current_keys_index, index_typed_constant(xor_mask)); + // current_keys_index < compare_keys_index + llvm::Value* is_smaller_index = + b->CreateICmpSLT(current_keys_index, compare_keys_index); + // compare_keys_index < iteration_bound + llvm::Value* index_is_inbounds = b->CreateICmpSLT( + compare_keys_index, index_typed_constant(iteration_bound)); + llvm::Value* do_comparison = + needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds) + : b->getInt1(true); + + // if (is_smaller_index && index_is_inbounds) + KernelSupportLibrary ksl(b); + ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + auto key1 = read_element(0, current_keys_index); + auto key2 = read_element(0, compare_keys_index); + auto compare_key1 = key1; + auto compare_key2 = key2; + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the + // sort has a predictable behavior in the presence of NaNs. Rather + // than using floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the + // obvious order, -0 is ordered before 0, and -NaN and NaN appear at + // the beginning and end of the ordering. 
+ auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } + // If key2 < key1 + auto is_smaller_than = + b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1); + if (iota_values_parameter_index >= 0) { + auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); + auto key_index1 = + read_element(iota_values_parameter_index, current_keys_index); + auto key_index2 = + read_element(iota_values_parameter_index, compare_keys_index); + auto index_is_smaller_than = + b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); + is_smaller_than = b->CreateOr( + is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + } + ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + // Swap key1 with key2. + write_element(0, current_keys_index, key2); + write_element(0, compare_keys_index, key1); + for (int64 i = 1; i <= num_values; ++i) { + // Also swap the values. 
+ auto value1 = read_element(i, current_keys_index); + auto value2 = read_element(i, compare_keys_index); + write_element(i, current_keys_index, value2); + write_element(i, compare_keys_index, value1); + } + }); + }); +} + +void EmitTiledCompareLoop( + const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, + int64 dimension_to_sort_bound, PrimitiveType keys_type, + absl::Span xor_masks, const std::vector& params, + const std::vector& param_shmem_buffers, + int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { + KernelSupportLibrary ksl(b); + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); + llvm_ir::AddRangeMetadata(0, tile_size / 2, + llvm::cast(thread_id)); + thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), + /*isSigned=*/true, "thread.id.x"); + + auto copy_loop_body = + [&](std::function + read_or_write) { + auto value_one = tiled_keys_index.GetConstantWithIndexType(1); + auto current_keys_index = + b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); + // We want to copy two adjacent elements. We first check whether the + // first index position is within bounds. + ksl.IfReturnVoid( + "smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + auto cache_index = b->CreateShl(thread_id, value_one); + read_or_write(cache_index, current_keys_index); + // Increment to go the next index position. + current_keys_index = b->CreateAdd(current_keys_index, value_one); + // Here we check whether the next index position is within bounds. 
+ ksl.IfReturnVoid( + "inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); + }); + }; + + // Copy operand tiles from the operand buffers to shared memory. + IrArray::Index keys_index = tiled_keys_index; + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = params[i].EmitReadArrayElement(keys_index, b); + b->CreateStore(value, + b->CreateGEP(param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), + cache_index})); + }); + } + // Wait until all reads have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + + // Now emit the bodies of the comparison loops. + auto read_element = [&](int64 operand, llvm::Value* index) { + return b->CreateLoad( + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + b->CreateStore( + value, + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + for (int64 xor_mask : xor_masks) { + // The index of the element pair to be compared within the tile stored in + // shared memory. We order the element pairs by the element with the smaller + // index. + auto element_pair_index = thread_id; + // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we don't + // need any bounds checks. + if (dimension_to_sort_bound % tile_size) { + // Otherwise we need a bounds check for the last tile. The last tile has + // size 'dimension_to_sort_bound' % 'tile_size'. 
+ ksl.IfReturnVoid( + "is_last_tile", + b->CreateICmpUGE( + b->CreateMul(tiled_keys_index[dimension_to_sort], + tiled_keys_index.GetConstantWithIndexType(2)), + tiled_keys_index.GetConstantWithIndexType( + RoundDownToNearest(dimension_to_sort_bound, tile_size))), + [&]() { + EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, + params.size() - 1, iota_values_parameter_index, + element_pair_index, xor_mask, + tiled_keys_index.GetType(), read_element, + write_element, b); + }, + [&]() { + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), + read_element, write_element, b, + /*needs_bounds_checks=*/false); + }); + } else { + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), read_element, + write_element, b, /*needs_bounds_checks=*/false); + } + // Wait until all comparisons have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + } + + // Copy the operand tiles back from shared memory to the operand buffers. + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = b->CreateLoad(b->CreateGEP( + param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), cache_index})); + params[i].EmitWriteArrayElement(keys_index, value, b); + }); + } + // We should normally synchronize here to make sure all writes have happened. + // However the very next thing each thread does is reading 2 elements from the + // operand buffer and writing it into the same location in shared memory from + // which it previously copied it to the operand buffer, and we synchronize + // after this has happened. 
We can be sure that a thread always writes to the + // same location in shared memory because we have exactly tile_size / 2 many + // threads, and the linear index calculated by ParallelLoopEmitter uses + // linear_index = blockIdx.x * blockDim.x + threadIdx.x; } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions) { - const Shape& keys_shape = keys_array.GetShape(); + int64 iota_values_parameter_index, + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, + const int64 tile_size) { + // Iterate through the keys shape in physical order, but skip the dimension to + // sort and make it the innermost loop which is the loop where the comparisons + // happen. In the dimension to sort, if we use tiling, we iterate through it + // in tiles of 64 elements each, so we use another loop that happens within + // one thread to process this tile worth of data (thereby combining several + // comparison stages of the bitonic sort algorithm because they all happen + // within those 64 elements and are therefore independent of the other + // comparisons). - // Create loop nests which loop through the operand dimensions. The sort - // dimension is handled in the innermost loop which performs the sorting. 
- ForLoopNest loop_nest(name, b); - IrArray::Index keys_index = - loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys"); - if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) { - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b); + const Shape& keys_shape = keys_array.GetShape(); + int64 rank = ShapeUtil::Rank(keys_shape); + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + std::vector dimensions_in_iteration_order(rank); + std::vector iteration_order_to_logical_order(rank); + int64 dim = 0; + for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) { + if (dimension != dimension_to_sort) { + dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension); + iteration_order_to_logical_order[dim++] = dimension; + } } + dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim; + iteration_order_to_logical_order[dim] = dimension_to_sort; - // 'compare_keys_index' is the index of the element that 'keys_index' should - // be compared to. - IrArray::Index compare_keys_index(keys_index.GetType()); - for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { - if (dimension != dimension_to_sort) { - compare_keys_index.push_back(keys_index[dimension]); - } else { - compare_keys_index.push_back(nullptr); + Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), + dimensions_in_iteration_order); + std::vector params(1, keys_array); + params.insert(params.end(), values_arrays.begin(), values_arrays.end()); + + // Allocate shared memory for the tiled compare loop. 
+ std::vector param_shmem_buffers(params.size(), nullptr); + if (xor_masks.size() > 1) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + for (int64 i = 0; i < params.size(); ++i) { + llvm::Type* tile_type = + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + params[i].GetShape().element_type(), module), + tile_size); + param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( + module, tile_type, absl::StrCat(name, "_tile_param_", i)); } } - // Naive C++ code for the inner compare loop: - // - // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { - // int64 j = i ^ xor_mask; - // if (i < j && j < dimension_to_sort_bound) { - // int64 min_key = std::min(keys[i], keys[j]); - // keys[j] = std::max(keys[i], keys[j]); - // keys[i] = min_key; - // } - // } - // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter - - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(), - {dimension_to_sort_bound}); auto compare_loop_body_emitter = - [&](const IrArray::Index& compare_index) -> Status { - keys_index[dimension_to_sort] = compare_index[0]; - compare_keys_index[dimension_to_sort] = - b->CreateXor(compare_index[0], xor_mask); - EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, values_arrays, b); + [&](const IrArray::Index& tiles_index) -> Status { + // Naive C++ code for the inner compare loop: + // + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ xor_mask; + // /* emitted in EmitCompareLoopBody() */ + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + IrArray::Index 
keys_index(tiles_index.GetType(), rank); + for (int64 i = 0; i < rank; ++i) { + keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; + } + if (xor_masks.size() > 1) { + EmitTiledCompareLoop(keys_index, dimension_to_sort, + dimension_to_sort_bound, keys_shape.element_type(), + xor_masks, params, param_shmem_buffers, + iota_values_parameter_index, tile_size, b); + } else { + auto read_element = [&](int64 operand, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + return params[operand].EmitReadArrayElement(keys_index, b); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + keys_index[dimension_to_sort] = index; + params[operand].EmitWriteArrayElement(keys_index, value, b); + }; + EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), + values_arrays.size(), iota_values_parameter_index, + tiles_index[rank - 1], xor_masks[0], + tiles_index.GetType(), read_element, write_element, + b); + } return Status::OK(); }; - if (launch_dimensions != nullptr) { - TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter, - compare_shape, - *launch_dimensions, b) - .EmitLoop(name)); - } else { - TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b) - .EmitLoop(name)); - } - - // Set the IR builder insert point to the exit basic block of the outer most - // loop. This ensures later instructions are inserted after this loop nest. 
- b->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); + return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape, + launch_dimensions, b) + .EmitLoop(name); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 2f3bcda2307bcbb35a03b9e71dbbe44e366b3820..685f9383acba416f51681270e4037d56abb4b6ea 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -29,13 +30,17 @@ namespace xla { namespace llvm_ir { // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' // dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, -// the inner compare loop will not be parallelized. +// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' +// contains only powers of 2, or values 2^k - 1 (k > 0). If +// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand +// that is a iota and can be used to make the sorting stable. 
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions); + int64 iota_values_parameter_index, + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index cca37556173bb95ef062b59ab0a4bf9ca7c496fe..2180ac845dd71da3a67b0a818866540764ce0848 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -220,4 +220,10 @@ StatusOr LocalService::GlobalDataToShapedBuffer( return buffers[replica_number]; } +StatusOr LocalService::RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag) { + return allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), tag); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 3b4f0b50832d6d2b64528ffb63eb5c7375396aec..f56ba32b04b9bf3aba75654bdb98887ad22e6791 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -63,6 +63,11 @@ class LocalService : public Service { StatusOr GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. 
+ StatusOr RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/map_inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc index 84059dd0f71ee8fc0a25703cbab2268d7dc149a8..fd18bfdc3e7f4b5f94237c554c3e6ca8bd065a35 100644 --- a/tensorflow/compiler/xla/service/map_inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using MapInlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(MapInlinerTest, MapMax) { @@ -59,12 +59,12 @@ TEST_F(MapInlinerTest, MapMax) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -93,12 +93,12 @@ TEST_F(MapInlinerTest, MapConstant) { HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module 
= CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -131,12 +131,12 @@ TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 2ca527bc4cb8f66a085c1e6a7cbb8ddaedbfc07e..9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -257,7 +258,7 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, } void MultiOutputFusion::RecomputeReachability() { - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); } void MultiOutputFusion::UpdateReachability( @@ -317,9 +318,9 @@ bool MultiOutputFusion::Perform() { << instr2->fused_instructions_computation()->ToString( HloPrintOptions().set_indent_amount(1)); } + Update(instr1, instr2); HloInstruction* ret = Fuse(instr1, instr2); set_is_fused(ret == instr1 ? instr2 : instr1); - Update(instr1, instr2); changed = true; VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" << ret->fused_instructions_computation()->ToString( diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 9508ab2ed1d38ec40983d8892ec8875b848fb21b..1c7583ece720f9e4d4b71a6279b976fed40e10cb 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 6152cdc6099a182f1ed98f9501613e0aa123cdbb..f196d9b7f586474f4a5e997b26acf93b732afdda 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -90,8 +90,8 @@ namespace xla { // are provided below. // // Example nullary instruction: -// Param() == Op().WithOpcode(HloOpcode::kParam) -// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) +// Parameter() == Op().WithOpcode(HloOpcode::kParameter) +// Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter) // // Example unary instruction: // Abs() == Op().WithOpcode(HloOpcode::kAbs) @@ -1067,8 +1067,10 @@ XLA_UNOP_PATTERN(RoundNearestAfz) XLA_UNOP_PATTERN(Bitcast) XLA_UNOP_PATTERN(Broadcast) XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(CrossReplicaSum) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index c522e7ae23b734090f85d241bf365fccc37f0adb..c227106511c2c17b44569d3b696cd7d764226e81 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "absl/strings/ascii.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -59,20 +59,15 @@ string CanonicalPlatformName(const string& name) { /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { - se::MultiPlatformManager::PlatformMap platform_map; - se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms( - [&platform_map](se::MultiPlatformManager::PlatformMap* map) { - platform_map = *map; - return se::port::Status::OK(); - }); - if (platform_map.empty()) { + std::vector all_platforms = + se::MultiPlatformManager::AllPlatforms(); + if (all_platforms.empty()) { LOG(WARNING) << "no executor platforms available: platform map is empty"; } // Gather all platforms which have an XLA compiler. std::vector platforms; - for (auto& platform_pair : platform_map) { - auto* platform = platform_pair.second; + for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { platforms.push_back(platform); @@ -222,8 +217,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { // fix the number of devices to one. However we do let the user override // this behavior to help run tests on the host that run models in parallel // across multiple devices. 
- device_count = legacy_flags::GetDebugOptionsFromFlags() - .xla_force_host_platform_device_count(); + device_count = + GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 688cceff0cd10df62a4093f00ad3331ca77652e0..b70cb7057477a338bfb36ebab76237b30d018e41 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -111,7 +111,7 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( VLOG(2) << "Adding to operand " << i << ": " << operand; if (!is_valid_shape(operand->shape())) { - VLOG(2) << "Skipped: value is not an F32 vector"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } @@ -168,7 +168,7 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( << instruction->ToString(); if (!is_valid_shape(instruction->shape())) { - VLOG(2) << "Skipped: value is not an F32 nonscalar array"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 0b4e82e8d606cf2cacfab42d07c2201939d5e10b..76c6a87f176ec9c6f8e49c25278c6dad703e3c7c 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass { // equivalent behavior can be obtained by adding ReducePrecision // instructions after the instructions that pull the F32 arrays out of // the tuples. 
- // - // TODO(b/64093391): Remove the IsScalar check once this won't cause - // failures on the GPU backend if the ReducePrecision instruction ends up - // inserted between a scalar constant and the init_value argument of a - // Reduce operation. - return shape.element_type() == PrimitiveType::F32 && - !ShapeUtil::IsScalar(shape); + return shape.element_type() == PrimitiveType::F32; } // Is this instruction one such that following or preceding it with a new diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index 69e4b534bd8e3aeab8b729f3e594a10b4368f15f..efeec96571455d8a9e4b7837dd7286392c12f1a3 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -54,7 +54,34 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm expected graph after adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); +} + +TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + + // Create a simple graph with a parameter feeding a unary cosine function. 
+ HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -84,7 +111,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -113,7 +140,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -146,7 +173,7 @@ TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { HloInstruction* d = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -178,7 +205,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -215,7 +242,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -242,7 +269,7 @@ TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -268,7 +295,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -294,7 +321,7 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, a, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -321,7 +348,7 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 5, 10)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. 
@@ -349,7 +376,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -375,7 +402,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -411,7 +438,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -458,7 +485,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. 
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index fcf269eee925c2ddb7511d70e71bd815e4b8c24a..341659b15c4c7355d39739ee171a4a749d87e929 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,9 +34,10 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ReshapeMoverTest : public HloVerifiedTestBase {}; +class ReshapeMoverTest : public HloTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -50,12 +51,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -74,6 +75,7 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { // Verifies that the reshape is not moved, since rng0 is trivially 
reshapable // and therefore there is no nontrivial reshapes to move. TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( @@ -92,18 +94,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); } TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -117,12 +120,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -130,6 +133,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = 
ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -143,11 +147,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -177,6 +181,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { // | // reshape4 TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction( @@ -196,12 +201,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); @@ -221,6 +226,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { // Verifies that the reshape0 does not sink below add, because param1 is not // trivially reshapable nor is a Reshape/Transpose. 
TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -232,11 +238,11 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -257,6 +263,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // Verifies that we don't unnecessarily sink reshapes, which are in fact // trivial reshapes. 
TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -275,12 +282,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -309,6 +316,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { // // (note that reshape1 here is trivial). 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -320,12 +328,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -348,6 +356,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { // For now we treat it as non-trivial, so we verify that we don't sink the // reshapes in this case. 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -362,12 +371,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); @@ -376,6 +385,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -389,14 +399,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, 
param1))); @@ -405,6 +415,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7}); @@ -423,13 +434,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -438,6 +449,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { } TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto pred_shape = ShapeUtil::MakeShape(PRED, {}); @@ -452,11 +464,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); 
@@ -477,6 +489,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { // // We expect reshape{0,1} AND reshape{2,3} to be lifted. TEST_F(ReshapeMoverTest, MultiplePasses) { + auto m = CreateNewVerifiedModule(); auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); @@ -500,14 +513,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -526,11 +539,11 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Transpose(op::Multiply())); } @@ -555,8 +568,8 @@ TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_FALSE(changed); } @@ -580,10 +593,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { } )"; - 
ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); } @@ -597,10 +610,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reshape(op::Add())); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 6f9094a5c2e882f4bc1531efdef654a6afa2ddb6..13fd6bc0093f3bb94c61fc46dc16ecfea03eb326 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -23,9 +23,9 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -292,7 +292,7 @@ StatusOr> Service::CreateModuleConfig( config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { - config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config->set_debug_options(GetDebugOptionsFromFlags()); } if (execute_backend_ != nullptr && @@ -760,38 +760,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - ExecuteGraphParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - -Status Service::PickParallelResponse( - const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { - // The "result device" selection is a bit hacky, but better than assuming it - // is device 0. We have b/76035356 for restructuring the client API to clean - // up the current asymmetries and support more functionalities. 
- for (int64 i = 0; i < parallel_result.responses_size(); ++i) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica( - parallel_result.responses(i).output(), 0)); - const Shape& shape = buffer->on_host_shape(); - if (!ShapeUtil::IsEmptyTuple(shape)) { - *result = parallel_result.responses(i); - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return Status::OK(); - } - } - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - VLOG(1) << "Defaulting to device 0 result"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -836,10 +804,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute-graph request"; - +Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { + VLOG(1) << "running compile request"; if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -847,22 +813,21 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return InvalidArgument("programe shape may not be empty"); } - // If we received multiple device handles, we must partition the module. 
if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); + return InvalidArgument( + "The compile request does not support multiple device handles."); } - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - + std::vector argument_shapes; + absl::c_transform(arg->input_shape_with_layout(), + std::back_inserter(argument_shapes), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(arg->computation().host_program_shape(), - replicated_arguments.front(), - arg->execution_options())); + argument_shapes, &arg->execution_options())); + VLOG(3) << "Compile created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, @@ -871,6 +836,48 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); + + VLOG(1) << "successfully completed 'compile' request"; + return Status::OK(); +} + +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { + VLOG(1) << "running execute request"; + if (!arg->has_handle()) { + return InvalidArgument("execution handle should not be empty"); + } + TF_ASSIGN_OR_RETURN(auto executable, + compilation_cache_.LookUp(arg->handle())); + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + // Check that the replicated_arguments has the same shape and layout as the + // module config used when creating the exectuable. 
+ const int64 num_module_args = + executable->module_config().entry_computation_layout().parameter_count(); + if (num_module_args != arg->arguments_size()) { + return InvalidArgument( + "The executable expects %lld arguments, but sees %lld.", + num_module_args, arg->arguments_size()); + } + for (int64 i = 0; i < num_module_args; i++) { + const Shape& shape_module = + executable->module_config().entry_computation_layout().parameter_shape( + i); + const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape(); + if (!ShapeUtil::Equal(shape_module, shape_arg)) { + return InvalidArgumentStrCat( + "The executable exepcts the ", i, "th argument in shape ", + ShapeUtil::HumanStringWithLayout(shape_module), " but sees ", + ShapeUtil::HumanStringWithLayout(shape_arg)); + } + } + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( execute_backend_->default_stream_executor())); @@ -884,9 +891,10 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + arg->computation().name(), result->mutable_profile())); + ExecuteAndRegisterResult(executable.get(), replicated_arguments, + execute_backend_.get(), + "result of " + executable->module().name(), + result->mutable_profile())); if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN( @@ -898,7 +906,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } - VLOG(1) << "successfully completed 'execute-graph' request"; + VLOG(1) << "successfully completed 'execute' request"; return Status::OK(); } @@ -949,21 +957,6 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, return Status::OK(); } -namespace { - -// Creates a clone of the given shaped buffer with the given device ordinal. The -// shape and DeviceMemoryBase values of the clone are identical to the original. 
-std::unique_ptr CloneShapedBufferOnDevice( - const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = absl::make_unique( - shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), - shaped_buffer.platform(), device_ordinal); - clone->buffers() = shaped_buffer.buffers(); - return clone; -} - -} // namespace - Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(Literal literal, diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 8cf1a7b9f01fbb3572c6849c8b18e14174ced89f..11e1a79552fbd944ab28da129b08cfe676fb08e9 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -22,11 +22,12 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" +#include "tensorflow/compiler/xla/service/compilation_cache.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" @@ -90,11 +91,14 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Executes a computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. - // Returns global data output and execution timing. - Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + // Compiles a computation into an executable. 
The request contains the whole + // computation graph. Returns the handle to the executable. + Status Compile(const CompileRequest* arg, CompileResponse* result) override; + + // Executes an executable with the provided global data passes as immutable + // arguments. The request contains the handle to the executable. Returns + // global data output and execution timing. + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each @@ -179,10 +183,6 @@ class Service : public ServiceInterface { absl::Span arguments, const ExecutionOptions& execution_options); - // Picks a parallel response and fills the result. - Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, - ExecuteResponse* result); - // Prepare the executors for executing parallel. StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, @@ -254,11 +254,6 @@ class Service : public ServiceInterface { Backend* backend, absl::Span device_handles, absl::Span result_tags, ExecutionProfile* profile); - // Executes a single computation which has more than one target device. - // The N devices are expected to all return an empty tuple, but one, which - // will be the result of this computation. - Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. @@ -281,6 +276,9 @@ class Service : public ServiceInterface { ServiceOptions options_; + // Cache containing previously built Executables. + CompilationCache compilation_cache_; + // Tracks channels created via the API. 
ChannelTracker channel_tracker_; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 2f8f092303ed1821d9bff021da0e835f1878f5ed..2bfc1676bddc66bdc90052589ed3024510c24d8f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2031,6 +2031,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } +/* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( + const Shape& shape, int64 dimension) { + if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", + dimension); + } + + // TODO(b/119580730): Remove this restriction when very large dimension size + // is needed. + if (shape.dimensions(dimension) > std::numeric_limits::max()) { + return InvalidArgument( + "GetDimensionSize's input shape is %s, the %dth dimension exceeds the " + "UINT_MAX limit.", + ShapeUtil::HumanString(shape), dimension); + } + + return ShapeUtil::MakeShape(U32, {}); +} + /* static */ StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { @@ -2833,6 +2852,15 @@ Status ValidateScatterDimensionNumbers( } } + // Validate window size. + auto window_size = dim_numbers.update_window_dims_size() + + dim_numbers.inserted_window_dims_size(); + if (window_size != ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Scatter op has window of size %d; doesn't match operand of rank %d.", + window_size, ShapeUtil::Rank(operand_shape)); + } + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. 
if (dim_numbers.scatter_dims_to_operand_dims_size() != scatter_indices_shape[dim_numbers.index_vector_dim()]) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index cd4e5ab52ca5e33424f2b78f83cc94961b254493..31ef4b2e41078f87731a1eff58e37409a6004ba4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -291,6 +291,9 @@ class ShapeInference { const Shape& updates_shape, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); + static StatusOr InferGetDimensionSizeShape(const Shape& shape, + int64 dimension); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7b65e8c1c9d2bc730c6c8550e9265b69fdde71cf..4639e32db4d59080a9e85e46983fac61d9e76be9 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -2673,5 +2673,23 @@ TEST_F(ScatterGatherShapeInferenceTest, << statusor.status(); } +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_InsufficientWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Scatter op has window of size 4; doesn't match operand of rank 5.")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc 
b/tensorflow/compiler/xla/service/shaped_buffer.cc index 56952e3adae59656605a12fd499162504a2a3379..28a30b5ee2dbcb5012804578d4d037c241045309 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -157,4 +157,23 @@ void ScopedShapedBuffer::Deallocate() { } } +ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) { + const xla::Shape& sub_on_host_shape = + xla::ShapeUtil::GetSubshape(on_host_shape(), {index}); + const xla::Shape& sub_on_device_shape = + xla::ShapeUtil::GetSubshape(on_device_shape(), {index}); + + ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape, + memory_allocator(), device_ordinal()); + auto src_it = buffers().find(index); + auto dst_it = output.buffers().begin(); + while (dst_it != output.buffers().end()) { + dst_it->second = src_it->second; + src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0); + ++src_it; + ++dst_it; + } + return output; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e1d26da4a20c0105be304b1a34c81515fcdc6b7f..f5210c9cfa6b29853bcd0f5bfd581ee3e116a509 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -176,6 +176,11 @@ class ScopedShapedBuffer : public ShapedBuffer { // It's the caller's job to ensure that the memory contained therein is freed. TF_MUST_USE_RESULT ShapedBuffer release(); + // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer + // that holds ownership of the subtree. Sets the buffers corresponding to the + // subtree to null in 'this'. 
+ ScopedShapedBuffer TakeSubTree(ShapeIndexView index); + protected: void Deallocate(); diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index d69e6362e91e4696dab3c46d99a981c67b593a1c..ca64bd3c8dd2baa686db2b85c937a034b37ab22b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -107,5 +109,79 @@ TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { // TestAllocator's destructor checks that all memory was freed. } +TEST(ScopedShapedBufferTest, TestTakeSubTree) { + TestAllocator allocator; + + Shape s = ShapeUtil::MakeShape(F32, {1}); + s = xla::ShapeUtil::MakeTupleShape(std::vector(2, s)); + s = xla::ShapeUtil::MakeTupleShape(std::vector(3, s)); + + ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0); + sb.buffers().ForEachMutableElement( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + TF_ASSERT_OK_AND_ASSIGN( + OwningDeviceMemory m, + allocator.Allocate(/*device_ordinal=*/0, /*size=*/77)); + *buffer = m.Forget(); + }); + ShapeTree buffers = sb.buffers(); + + // Takes a subtree out of 'sb', and verifies the buffers are as expected. 
+ xla::ShapeIndex subtree_index = {1}; + ScopedShapedBuffer output = sb.TakeSubTree(subtree_index); + + output.buffers().ForEachElement([&](const xla::ShapeIndex& sub_index, + const se::DeviceMemoryBase& buffer) { + xla::ShapeIndex orig_index = subtree_index; + for (int i : sub_index) { + orig_index.push_back(i); + } + EXPECT_TRUE(buffers.find(orig_index)->second.IsSameAs(buffer)); + }); + sb.buffers().ForEachElement( + [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + if (ShapeIndexView(index).StartsWith(subtree_index)) { + EXPECT_TRUE(buffer.is_null()); + } else { + EXPECT_TRUE(buffers.find(index)->second.IsSameAs(buffer)); + } + }); +} + +// Test TakeSubTree with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). +void BM_TakeSubTree(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + TestAllocator allocator; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. 
+ (void)shaped_buffer.TakeSubTree(/*index=*/{fan_out / 2}).release(); + } + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_TakeSubTree) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 79b5c09abb355cd067a4891af558c8c44d80d88e..17cdaa74fc328d156292f5af828d4222a9a01f1f 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -172,7 +172,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - auto module = CreateNewModule("fuse_with_constant_operands"); + auto module = CreateNewVerifiedModule("fuse_with_constant_operands"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( @@ -247,7 +247,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -302,7 +302,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* 
entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -362,7 +362,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -428,7 +428,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index d9ebebf74ed846aa05326a4df72019ef3e71ad88..10ef2d38fa21c3e93c270535bc99b2f76435337d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -48,7 +48,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { } void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); module_->AddEntryComputation(std::move(computation)); } @@ -809,7 +809,7 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { class PointsToAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -1176,7 +1176,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - 
module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -1211,7 +1211,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 516754e2110ee50a597818c4a8bcfbfbb76c5cec..65b0f8c804475d8f22fff9798e79c9881a51f1f1 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloVerifiedTestBase { +class TupleSimplifierTest : public HloTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -65,10 +65,10 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, scalar_shape_, "param2")); builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -78,10 +78,10 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { HloInstruction::CreateParameter(0, tuple_shape_, "param")); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -98,12 +98,12 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), gte); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -125,13 +125,13 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -157,12 +157,12 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0)); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), element); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -182,12 +182,12 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -207,19 +207,19 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { // Verify that the root computation can be excluded - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloInstruction* p0; HloInstruction* p1; @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module, /*change_expected=*/true, /*exclude_entry=*/true); + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 541b117e0299c94de330604ec5c16e20f07c425f..68e2569f66bea9ec1223e454d1ead0efc7b9498e 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" namespace xla { @@ -229,4 +232,96 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, return nullopt; } +// If the only user of this instruction is a get-tuple-element, return that +// get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may +// get a false negative if there are several copies of the same GTE, or there +// are unused GTEs, but we can live with this. 
+static HloInstruction* GetOnlyGTE(HloInstruction* inst) { + if (inst->user_count() != 1) { + return nullptr; + } + + HloInstruction* user = inst->users().back(); + if (user->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; + } + return user; +} + +optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { + // If we know the exact trip count, it's also the upper bound. + auto exact_trip_count = ComputeWhileLoopTripCount(while_op); + if (exact_trip_count) { + VLOG(2) << "Loop has exact trip count."; + return exact_trip_count; + } + + // There is one more case we know how to handle. If the loop condition only + // looks at one element of the tuple, and the loop body sets this element to a + // constant, there are two options: + // 1) Evaluating the condition on this constant returns true. In this case, + // the loop either executes 0 times, or is an infinite loop, depending on the + // init value. + // 2) Evaluating the condition on this constant returns false. In this case, + // the loop executes 0 or 1 times, depending on the init value. This means + // that, regardless of the init value, the upper bound on the trip count is 1. + + // Check whether the condition depends on a single parameter, and find out + // which. + auto* while_cond = while_op->while_condition(); + auto* while_cond_param = while_cond->parameter_instruction(0); + auto* cond_gte = GetOnlyGTE(while_cond_param); + if (!cond_gte) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // Now check whether this gets set to a constant by the while body. 
+ auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(3) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + int64 indvar_index = cond_gte->tuple_index(); + auto* while_body_indvar = while_body_root->operand(indvar_index); + if (while_body_indvar->opcode() != HloOpcode::kConstant) { + VLOG(3) << "While body does not set the IV to a constant: " + << while_body_indvar->ToString(); + return nullopt; + } + + // We have a constant. Evaluate the condition on this constant. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + Literal fake_input = Literal::CreateFromShape(while_cond_param->shape()); + TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), + /*dest_shape_index=*/{indvar_index}, + /*src_shape_index=*/{})); + StatusOr eval_result = + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + + if (!eval_result.ok()) { + VLOG(2) << "Couldn't evaluate while loop condition."; + return nullopt; + } + + Literal cond_result_pred = std::move(eval_result.ValueOrDie()); + CHECK(ShapeUtil::Equal(cond_result_pred.shape(), + ShapeUtil::MakeShape(PRED, {}))); + + // Per the explanation above, if the evaluated condition returns false, the + // loop executes at most once. 
+ bool cond_returns_true = cond_result_pred.GetFirstElement(); + if (!cond_returns_true) { + VLOG(2) << "Upper bound on the trip count is 1"; + return 1; + } + + VLOG(2) << "Loop has no known upper bound on the trip count."; + return nullopt; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index bf497f4892b95c927379411468a66d8961465413..ac69a727bd6b403672a676400993fb7d8afc0a55 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -28,6 +28,10 @@ namespace xla { absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, int64 max_value_returned = 128); +// Returns an upper bound on the trip count of the loop if it's statically +// known, nullopt otherwise. +absl::optional ComputeWhileLoopTripCountUpperBound( + HloInstruction *while_op); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1da0fbeac89a93eaaef893e5f25dd3b87cc1d5d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class WhileLoopAnalysisTest : public HloTestBase {}; + +TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(-1) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); +} + +TEST_F(WhileLoopAnalysisTest, NoUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(42) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], 
s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), absl::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, ExactBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + index = s32[] get-tuple-element(p_body), index=1 + one = s32[] constant(1) + inc = s32[] add(index, one) + ROOT root = (f32[2], s32[]) tuple(val, inc) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] less-than(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] constant(0) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 067cfcc17d65860a249de4d9e31703df12091d3a..8b381dec07397c1427e98bc30511ac21dc577610 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -46,8 +46,9 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( return Status::OK(); } -StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( +StatusOr 
WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( HloInstruction* while_instr) { + HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); const HloInstruction& init_value = *while_instr->operand(0); @@ -57,24 +58,48 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( bool changed = false; - for (HloInstruction* invariant_gte : - WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - int64 index = invariant_gte->tuple_index(); + absl::flat_hash_map> + conditional_gte_index_to_insts = + WhileUtil::GetGTEsMapForWhileConditional(*while_cond); + std::vector invariant_body_gtes = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + for (HloInstruction* invariant_body_gte : invariant_body_gtes) { + int64 index = invariant_body_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); - // Should have at least one user that's not while_body_root. - if (invariant_gte->user_count() <= 1) { + // Original value should be a constant. + if (invariant_value.opcode() != HloOpcode::kConstant) { continue; } - if (invariant_value.opcode() == HloOpcode::kConstant) { - auto* constant_instr = + // Sink into the while_body. + // Should have at least one user that's not while_body_root. + if (invariant_body_gte->user_count() > 1) { + HloInstruction* constant_instr = while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( - invariant_gte, constant_instr, while_body->root_instruction(), + invariant_body_gte, constant_instr, while_body->root_instruction(), index)); changed = true; } + + // Check if there is a corresponding GTE in while_conditional. + auto it = conditional_gte_index_to_insts.find(index); + if (it == conditional_gte_index_to_insts.end()) { + continue; + } + + for (HloInstruction* invariant_cond_gte : it->second) { + // Should have at least one user. 
+ if (invariant_cond_gte->user_count() > 0) { + HloInstruction* constant_instr = while_cond->AddInstruction( + invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR( + invariant_cond_gte->ReplaceAllUsesWith(constant_instr)); + changed = true; + } + } } return changed; @@ -115,10 +140,8 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { } for (HloInstruction* while_instr : while_instrs) { - // We only sink into while loop bodies, but this can be extended to - // transform conditions as well. TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileBody(while_instr)); + TrySinkingConstantsIntoWhileLoop(while_instr)); changed |= result; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 577bad6c7062d2ee40271e407e8eed7655fa13bf..a866bc1264b4013bb7530b5e02b546e6f78d676b 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -23,8 +23,8 @@ limitations under the License. namespace xla { // Sinks while loop invariant values that happen to be constants into the while -// loop body. This is probably not a win in isolation but may unlock further -// optimizations like constant folding. +// loop body and conditional. This is probably not a win in isolation but may +// unlock further optimizations like constant folding. // // state = (..., const, ...) // while (pred(state)) { @@ -46,22 +46,19 @@ namespace xla { // tuple trivially loop invariant. WhileLoopSimplifier will later get rid of // `v`. // -// We only sink into while loop bodies, but this can be extended to transform -// conditions as well. -// // TODO(b/79121449): We should also sink broadcasts of constants. 
class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; absl::string_view name() const override { - return "while-loop-invariant-code-motion"; + return "while-loop-constant-sinking"; } StatusOr Run(HloModule* module) override; private: - StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); + StatusOr TrySinkingConstantsIntoWhileLoop(HloInstruction* while_instr); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index d17b86fab5b14d13250a03fc8f74abb9661ed5ce..75d406435b6f58faecc86b82c33e9e2dd6bccbea 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -242,5 +242,178 @@ ENTRY entry { } } } + +TEST_F(WhileLoopConstantSinkingTest, ConditionalSinkConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[]) p_body), index=1 + ROOT root = (f32[],f32[]) tuple(add, p_body.1) +} + +condition { + p_cond = (f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + while_init = (f32[],f32[]) tuple(const_0, const_1) + ROOT while = (f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* 
while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalTupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[],(f32[],f32[])) parameter(0) + p_b.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_b), index=0 + p_b.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_b), index=1 + p_b.1.0 = f32[] get-tuple-element((f32[],f32[]) p_b.1), index=0 + add = f32[] add(p_b.0, p_b.1.0) + ROOT root = (f32[],(f32[],f32[])) tuple(add, p_b.1) +} + +condition { + p_c = (f32[],(f32[],f32[])) parameter(0) + p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 + p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 + p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 + ROOT result = pred[] less-than(p_c.0, p_c.1.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) + ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::Lt(_, op::GetTupleElement(op::Constant()))); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalDontCreateDeadConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) 
p_body), index=1 + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); + for (const HloInstruction* inst : while_condition->instructions()) { + if (inst->opcode() == HloOpcode::kConstant) { + EXPECT_GT(inst->user_count(), 0); + } + } +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalMultipleSameIndexGTEs) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add.0 = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1 + add.1 = f32[] add(p_body.1, const) + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add.0, add.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 
+ p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.0 = pred[] less-than(p_cond.0, p_cond.2) + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + ROOT result = pred[] and(lt.0, lt.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(0) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant()))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 9795b2830b6d9add82b89ac76b5438ddc3d2bfe8..41011176ffa91e885bc58364d1fb19617d3518ad 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -19,7 +19,9 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -143,6 +145,12 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( string while_instr_name = while_instr->ToString(print_no_metadata); VLOG(2) << "Trying to hoist from " << while_instr_name; + auto maybe_upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr); + if (maybe_upper_bound && *maybe_upper_bound <= 1) { + VLOG(2) << "Loop has a trip count of at most 1, skipping."; + return false; + } + HloComputation* while_body = while_instr->while_body(); // Maps instructions in the while body to instructions hoisted outside the @@ -180,6 +188,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( return false; } + // LICM in the presence of domain instructions is complex, bail. + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + return false; + } + } + // instructions_to_replace[i] is hoisted into a loop invariant instruction // replacement_instructions[i]. std::vector instructions_to_replace; @@ -193,6 +208,37 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( continue; } + if (!hoist_size_inflating_ops_) { + // Check that hoisting the instruction doesn't cause a significant memory + // blow-up. LICM extends the live-range of the output of the hoisted + // instruction to be the entire while loop, which may be problematic on + // platforms where memory is limited. This can be especially harmful if + // the instruction has a significantly larger output than its input, e.g. + // kIota, kBroadcast or kConstant. 
+ int64 input_size = 0, output_size = 0; + + for (auto* operand : instruction->operands()) { + ShapeUtil::ForEachSubshape( + operand->shape(), + [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + input_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + output_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + + if (output_size > input_size) { + continue; + } + } + auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || unhoisted_invariant_instructions.count(op) || diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 3031899f71e0fd77f20448d9d7489798af01615c..bd6232dc0a988775a0490abbf6125daad8476295 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -34,8 +34,14 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { // Setting `hoist_constants` to false can be help if LICM is run in the mid // level HLO pipeline because hoisting constants out of while loop bodies can // break optimizations like constant folding. - explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) - : hoist_constants_(hoist_constants) {} + // Setting `hoist_size_inflating_ops` to false will forbid hoisting + // instructions where the size of the output(s) is larger than the size of the + // input(s). This is useful on platforms on which it's important to prevent + // blow-ups in memory size. 
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false, + bool hoist_size_inflating_ops = true) + : hoist_constants_(hoist_constants), + hoist_size_inflating_ops_(hoist_size_inflating_ops) {} ~WhileLoopInvariantCodeMotion() override = default; absl::string_view name() const override { @@ -49,6 +55,7 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { HloInstruction* while_instr); bool hoist_constants_; + bool hoist_size_inflating_ops_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b713c438bd7fcb2053709b0624f58ed..8e7c4bc8828552e197b41f874c070d496b85a382 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -26,7 +26,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { +class WhileLoopInvariantCodeMotionTest : public HloTestBase { public: // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. 
@@ -58,6 +58,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( } TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -76,19 +77,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -100,6 +100,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { } TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -135,19 +136,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); - return module().AddEmbeddedComputation(builder.Build()); + return 
m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -173,6 +173,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistTriviallyLoopVaryingComputation) { // Basic negative test: the add expression is not loop invariant. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -189,20 +190,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); @@ -210,6 +211,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -228,25 +230,26 @@ TEST_F(WhileLoopInvariantCodeMotionTest, builder.AddInstruction( HloInstruction::CreateTuple({gte_1, gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = 
builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = @@ -267,7 +270,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -277,14 +280,14 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); ASSERT_FALSE(simplified_loop); 
EXPECT_THAT(while_inst->while_body()->instructions(), @@ -294,6 +297,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the // bitcast either. + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); @@ -317,7 +321,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -327,15 +331,15 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), @@ -346,6 +350,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); Shape while_shape = @@ -367,21 +372,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -396,6 +400,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -416,22 +421,23 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - while_body = module().AddEmbeddedComputation(builder.Build()); + while_body = m->AddEmbeddedComputation(builder.Build()); } HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); 
builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -439,7 +445,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { HloComputation::Builder builder(TestName() + ".passthrough"); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + HloComputation* result = m->AddEmbeddedComputation(builder.Build()); result->AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); @@ -450,11 +456,11 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } @@ -482,14 +488,14 @@ ENTRY entry { )"; TEST_F(WhileLoopInvariantCodeMotionTest, 
HoistsConstantWhenAsked) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN( bool simplified_loop, - WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(m.get())); EXPECT_TRUE(simplified_loop); - HloComputation* while_body = module().GetComputationWithName("wide.body"); + HloComputation* while_body = m->GetComputationWithName("wide.body"); ASSERT_NE(while_body, nullptr); // We expect the while body to be the equivalent of: @@ -523,10 +529,98 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { } TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], f32[2], f32[2], s32[]) parameter(0) + val.0 = f32[2] get-tuple-element(p_body), index=0 + val.1 = f32[2] get-tuple-element(p_body), index=1 + add = f32[2] add(val.0, val.1) + const = s32[] constant(-1) + ROOT root = (f32[2], f32[2], f32[2], s32[]) tuple(val.0, val.1, add, const) + } + + condition { + p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=3 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], f32[2], f32[2], s32[]) tuple(param.0, param.0, param.0, param.1) + ROOT while = (f32[2], f32[2], f32[2], s32[]) while(while_init), 
condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(module.get())); + EXPECT_FALSE(simplified_loop); +} + +const char* const kInflatingTestCase = R"( +HloModule ModuleWithWhile + +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} + +body { + p_body = (f32[]) parameter(0) + iota = f32[1024, 1024] iota(), iota_dimension=0 + add = f32[1024, 1024] add(iota, iota) + constant = f32[] constant(1.0) + reduce = f32[] reduce(f32[1024, 1024] add, f32[] constant), dimensions={0,1}, to_apply=mul + ROOT root = (f32[]) tuple(reduce) +} + +condition { + p_cond = (f32[]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + param = f32[] parameter(0) + while_init = (f32[]) tuple(param) + ROOT while = (f32[]) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsInflatingByDefault) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true).Run(m.get())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = m->GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + EXPECT_THAT(while_body->instructions(), Not(Contains(op::Iota()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true, + /*hoist_size_inflating_ops=*/false) + .Run(m.get())); EXPECT_FALSE(simplified_loop); } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 
630d71e5ca25e9d282ce6283284a32d6f725a193..c4790a7f199a90ca81e5503b4256bd95df88d4f4 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -19,41 +19,19 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" namespace xla { +namespace m = match; using absl::optional; - -// Determines whether the given instruction is a send/recv node, or has a -// subcomputation which contains a send/recv node. -static bool IsOrContainsSendOrRecv(const HloInstruction* instr); - -// Determines whether the given computation contains a send or recv node. -static bool ContainsSendOrRecv(const HloComputation* comp) { - for (const auto* instr : comp->instructions()) { - if (IsOrContainsSendOrRecv(instr)) { - return true; - } - } - return false; -} - -static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { - if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kSendDone || - instr->opcode() == HloOpcode::kRecv || - instr->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (const auto& subcomp : instr->called_computations()) { - if (ContainsSendOrRecv(subcomp)) { - return true; - } - } - return false; -} +using hlo_query::ContainsInstrWithOpcode; // Tries to remove elements in a while loop's tuple that aren't used within the // loop. @@ -253,7 +231,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. 
std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond), /*extras=*/{}); + make_while_computation_replacements(while_cond)); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -266,8 +244,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements), - /*extras=*/{}); + while_body->CloneWithReplacements(std::move(while_body_replacements)); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -329,6 +306,147 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return true; } +// Removes each loop parameter (i.e. member of the while loop tuple) that is a +// constant and is the same in the while loop body and the while loop init. 
+static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + + absl::flat_hash_set constant_tuple_indices; + const auto& while_shape = while_init->shape(); + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (init_elem->opcode() == HloOpcode::kConstant && + body_elem->opcode() == HloOpcode::kConstant && + init_elem->literal() == body_elem->literal()) { + constant_tuple_indices.insert(i); + } + } + + if (constant_tuple_indices.empty()) { + return false; + } + + // OK, we found some constant elements of the while parameter! Eliminate + // them. + std::vector new_while_shape_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (!constant_tuple_indices.count(i)) { + new_while_shape_elems.push_back(while_shape.tuple_shapes(i)); + } + } + Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. 
+ std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + // Returns a new tuple without the elements of constant_tuple_indices. + auto remove_constant_elems = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), while_shape)); + + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (!constant_tuple_indices.count(i)) { + tuple_elems.push_back( + add_new_instr(HloInstruction::CreateGetTupleElement( + while_shape.tuple_shapes(i), instr, i))); + } + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + auto add_constant_elems = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); + + std::vector tuple_elems; + int64 j = 0; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (constant_tuple_indices.count(i)) { + tuple_elems.push_back(while_init->mutable_operand(i)); + } else { + tuple_elems.push_back( + add_new_instr(HloInstruction::CreateGetTupleElement( + while_shape.tuple_shapes(i), instr, j))); + ++j; + } + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Special case: constant_tuple_indices covers the whole while parameter, so + // the new while shape is the empty tuple. In this case, the value of the + // while loop is simply equal to the value of `init`. + // + // It's unfortunate to special-case this, but it's simpler than the + // alternative. The problem is that if our while parameter has no + // non-constant elems, the tuple returned by `add_constant_elems` won't depend + // on instr (the loop body/cond parameter), and therefore + // CloneWithReplacementPairs will *leave the parameter out entirely*, creating + // invalid HLO. 
+ if (ShapeUtil::IsEmptyTuple(new_while_shape)) { + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init)); + return true; + } + + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + add_constant_elems(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + add_constant_elems(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + remove_constant_elems( + add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, + add_constant_elems( + computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + add_new_instr(remove_constant_elems(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + // Tries to remove a while loop from the graph. // // - Loops with trip count of 0 can be replaced by the loop's "init" value. @@ -458,6 +576,414 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { return changed_cond || changed_body; } +// Converts a flat list of instructions into a tuple of the desired shape. For +// example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns +// a tuple of value ((A, B), C). +// +// desired_shape must be a tuple. (This precondition allows us to return a +// unique_ptr rather than a raw ptr.) 
+static std::unique_ptr UnflattenTupleInstr( + absl::Span instrs, const Shape& desired_shape, + std::vector>* new_instrs) { + CHECK(ShapeUtil::IsTuple(desired_shape)) + << ShapeUtil::HumanString(desired_shape); + + // For each child shape in `desired_shape`, slice out the correct number of + // `instrs` and call UnflattenTupleInstr recursively. At each step we remove + // elements from `instrs` so that it only contains instructions we have not + // yet processed. + std::vector elems; + for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { + const Shape& subshape = desired_shape.tuple_shapes(i); + if (!ShapeUtil::IsTuple(subshape)) { + elems.push_back(instrs[0]); + instrs.remove_prefix(1); + continue; + } + + // Count the number of leaf nodes underneath desired_shape[i]. + int64 num_leaves = 0; + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + ++num_leaves; + } + }); + + std::unique_ptr subinstr = + UnflattenTupleInstr(instrs.subspan(0, num_leaves), + desired_shape.tuple_shapes(i), new_instrs); + elems.push_back(subinstr.get()); + new_instrs->push_back(std::move(subinstr)); + instrs.remove_prefix(num_leaves); + } + return HloInstruction::CreateTuple(elems); +} + +// Builds a vector whose elements are the values in the flattened tuple for +// `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the +// vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C). 
+static std::vector GetFlatTupleElems( + HloInstruction* instr, + std::vector>* new_instrs) { + const auto& shape = instr->shape(); + if (!ShapeUtil::IsTuple(shape)) { + return {instr}; + } + std::vector elems; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + new_instrs->push_back( + HloInstruction::CreateGetTupleElement(subshape, instr, i)); + auto* gte = new_instrs->back().get(); + auto flattened_subshape = GetFlatTupleElems(gte, new_instrs); + elems.insert(elems.end(), flattened_subshape.begin(), + flattened_subshape.end()); + } + return elems; +} + +static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + if (!ShapeUtil::IsNestedTuple(while_shape)) { + return false; + } + + std::vector flattened_shape_elems; + ShapeUtil::ForEachSubshape(while_shape, + [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + flattened_shape_elems.push_back(s); + } + }); + Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. 
+ std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto nested = [&](HloInstruction* instr) { + std::vector gtes; + const Shape& flat_shape = instr->shape(); + for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) { + gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement( + flat_shape.tuple_shapes(i), instr, i))); + } + auto nested_instr = + UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs); + CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape)) + << ShapeUtil::HumanString(nested_instr->shape()) << " vs " + << ShapeUtil::HumanString(while_shape); + return nested_instr; + }; + + auto flattened = [&](HloInstruction* instr) { + return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs)); + }; + + // Create a new while-condition computation, where parameter 0 has flat shape + // but all uses of it go through the nested shape. + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Create a new while-body computation, where parameter 0 has a flat shape and + // all uses of it go through the nested shape, and where the root has a flat + // shape constructed from the old nested root. + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_body->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + flattened(add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. 
+ new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile( + flattened_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + computation->AddInstruction(flattened(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + +// Tries to merge loop induction variables of a given type. +// +// In this pass we're only concerned with elements of the loop's tuple that +// are effective-scalars of type `elem_ty`. Some terminology: +// +// - The trip counter is the first element of the loop's tuple that starts at +// 0 and does x++ on each iteration. +// +// - An induction variable is an element of the loop's tuple that is not the +// trip counter and does `x += ` on each iteration of the loop. +// Negative constants are OK. +// +// This pass adds a trip counter if one isn't already present, then replaces +// each induction variable with +// +// + * . +// +// This reduces the number of scalar operations in the loop, which is important +// e.g. on GPUs, where each scalar operation is nontrivially expensive because +// it's a separate kernel launch. +// +// Returns the new loop if a change was made, or null if no change was made. +// Note that the new loop is not a valid replacement for the old loop; it may +// need to be wrapped in a tuple that changes its shape. We return the loop +// itself so that you can call TryMergeInductionVariables in a loop, once for +// each integral type elem_ty. 
+static StatusOr TryMergeInductionVariables( + HloInstruction* while_op, PrimitiveType elem_ty) { + CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty); + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return nullptr; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + + // The tuple index of the trip counter, if one is present. + absl::optional trip_counter; + // Maps the tuple index of each induction variable to its constant increment. 
+ absl::flat_hash_map induction_vars; + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + const auto& elem_shape = while_body_root->operand(i)->shape(); + if (!ShapeUtil::IsEffectiveScalar(elem_shape) || + elem_shape.element_type() != elem_ty) { + continue; + } + + HloInstruction* constant; + if (!Match(while_body_root->mutable_operand(i), + m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i), + m::Constant(&constant)))) { + continue; + } + if (!trip_counter && constant->literal().IsAll(1) && + while_init->operand(i)->IsConstant() && + while_init->operand(i)->literal().IsAll(0)) { + VLOG(10) << "Found existing trip counter at index " << i; + trip_counter = i; + } else { + VLOG(10) << "Found induction variable at index " << i; + induction_vars.emplace(i, Cast(constant)); + } + } + + // There's only something to simplify if we can either: + // + // - combine one or more induction vars with an existing trip counter, or + // - replace two or more induction variables with a new trip counter. + // + // Put another way, there's only something to simplify if the number of + // induction vars plus the number of existing trip counters (0 or 1) is >= 2. + if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) { + return nullptr; + } + + // OK, we're going to do the transformation! Set up some helpers. + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. + std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto add_binary_op = [&](const Shape& shape, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // Reshape lhs/rhs to the output shape if necessary. This deals with the + // fact that induction variables need only be effective scalars, not true + // scalars. 
+ if (!ShapeUtil::Compatible(shape, lhs->shape())) { + lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs)); + } + if (!ShapeUtil::Compatible(shape, rhs->shape())) { + rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs)); + } + return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs)); + }; + + auto add_gte = [&](HloInstruction* src, int64 idx) { + return add_new_instr(HloInstruction::CreateGetTupleElement( + src->shape().tuple_shapes(idx), src, idx)); + }; + + // Our new while loop will have the same shape as the old while loop, except + // we'll add a trip counter to the end if it wasn't originally present. + Shape new_while_shape = while_shape; + bool added_trip_counter = false; + if (!trip_counter) { + VLOG(10) << "Adding new trip counter to end of loop's tuple."; + trip_counter = new_while_shape.tuple_shapes_size(); + *new_while_shape.add_tuple_shapes() = + ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{}); + added_trip_counter = true; + } + + // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with + // shape `while_body->shape()` and where the induction variables are "reified" + // (i.e. they have value + * ). 
+ auto convert_to_old_form = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + const auto& elem_shape = while_shape.tuple_shapes(i); + if (!induction_vars.count(i)) { + tuple_elems.push_back(add_gte(instr, i)); + continue; + } + tuple_elems.push_back(add_binary_op( + elem_shape, HloOpcode::kAdd, add_gte(instr, i), + add_binary_op(elem_shape, HloOpcode::kMultiply, + add_gte(instr, *trip_counter), + add_new_instr(induction_vars.at(i)->Clone())))); + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Converts `root` into a tuple of the "new" form -- that is, to a tuple with + // shape `new_while_shape` and where the induction variables (but not trip + // counters) are replaced with their unchanging values. + auto convert_to_new_form = [&](HloInstruction* old_root, + HloParameterInstruction* loop_body_param) { + CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape)); + std::vector tuple_elems; + + // In the new form, induction variables come from `init`, everything else + // (including the trip counter if it's not one we created ourselves) comes + // from the `root` tuple unmodified. + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + tuple_elems.push_back( + add_gte((induction_vars.count(i) ? loop_body_param : old_root), i)); + } + // If we created a trip counter ourselves, add 1 to it in the next + // iteration. + if (added_trip_counter) { + tuple_elems.push_back(add_binary_op( + new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd, + add_gte(loop_body_param, *trip_counter), + add_new_instr( + HloInstruction::CreateConstant(LiteralUtil::One(elem_ty))))); + } + + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Creates a new init tuple, which is the same as the old init tuple except if + // we added a trip counter, it's set to 0. 
+ auto get_new_while_init = [&](HloInstruction* init) { + CHECK(ShapeUtil::Compatible(init->shape(), while_shape)); + if (!added_trip_counter) { + return init; + } + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + tuple_elems.push_back(add_gte(init, i)); + } + tuple_elems.push_back(add_new_instr( + HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty)))); + return add_new_instr(HloInstruction::CreateTuple(tuple_elems)); + }; + + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Creating the new while body proceeds in two steps. First we convert the + // users of the parameter to the old form. Then as a second + // CloneWithReplacement operation we convert the root to the new form. We + // have to do this in two steps because the new root needs to use the new + // param0, and during the first clone operation, only the *old-form* param0 is + // accessible. + // + // We have to add temp_new_while_body to the module because cloning a + // computation touches the module (to get its NameUniquer). 
+ HloComputation* temp_new_while_body = + module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({ + while_body->parameter_instruction(0), + convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_body->parameter_instruction(0)->name()))), + })); + std::unique_ptr new_while_body = + temp_new_while_body->CloneWithReplacementPairs({ + temp_new_while_body->root_instruction(), + convert_to_new_form( + add_new_instr(temp_new_while_body->root_instruction()->Clone()), + Cast( + temp_new_while_body->parameter_instruction(0))), + }); + TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body)); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + get_new_while_init(while_init))); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, convert_to_old_form(new_while))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return new_while; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -478,32 +1004,77 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { for (HloInstruction* while_op : while_ops) { // We can't remove while loops that contain send/recv nodes, because we rely // on the particular loop structure around the node matching on the send and - // recv sides. Removing dead while params requires us to remove the loop + // recv sides. Other while simplifications require us to remove the loop // and replace it with a new one, so we can't do that either. 
- if (ContainsSendOrRecv(while_op->while_body()) || - ContainsSendOrRecv(while_op->while_condition())) { + if (ContainsInstrWithOpcode(while_op->while_body(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone})) { VLOG(2) << "Not attempting to simplify while loop because it contains a " "send/recv node: " << while_op->ToShortString(); continue; } - StatusOr result = TryPropagateConstant(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); + changed |= result; + + TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); + changed |= result; + if (result) { + // Don't continue simplifying after successfully removing the while loop + // -- that would result in use-after-free nastiness. + continue; + } + + // TODO(b/119281462): Cowardly refuse to perform any of the following + // optimizations in the presence of kDomain instructions. It seems that + // modifying a while loop's tuple doesn't work when kDomain is present. + if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kDomain})) { + continue; + } + + // Each of the optimizations below modifies the while loop itself if it's + // successful, meaning that `while_op` is no longer valid after one of these + // transformations returns true. - result = TryRemoveWhileLoop(while_op); - TF_RETURN_IF_ERROR(result.status()); - if (result.ValueOrDie()) { - changed = true; - // Don't try to remove dead while params after successfully removing the - // while loop -- that would result in use-after-free nastiness. 
+ TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); + changed |= result; + if (result) { continue; } - result = TryRemoveDeadWhileParams(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; + if (result) { + continue; + } + + TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); + changed |= result; + if (result) { + continue; + } + + bool merged_induction_vars = false; + // Notably missing from this list are S16 and U16. These don't currently + // work because S/U16 literals are not implemented. + for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) { + TF_ASSIGN_OR_RETURN(auto* new_while_op, + TryMergeInductionVariables(while_op, elem_ty)); + if (new_while_op) { + while_op = new_while_op; + changed = true; + merged_induction_vars = true; + } + } + if (merged_induction_vars) { + continue; + } } XLA_VLOG_LINES(3, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 0bc5a0107bbcfb3b29a01d593fb79b89a863e49b..a378f179c63c788cd205ddbb784dee0e6b2106d7 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,11 +25,22 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. +// // - A while loop with static trip count of 1 is replaced by its body (sans // loop). +// // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // +// - If the while loop's parameter is a nested tuple, it's flattened to a +// single-level tuple. This is good because it usually reduces the number of +// kTuple instructions, but also because it unlocks additional optimizations +// (e.g. removing unused loop parameters). 
+// +// Flattening nested while loop tuples adds a whole mess of likely unnecessary +// kGetTupleElement and kTuple operations to the graph. We expect that tuple +// simplifier will be run afterwards. +// class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 1c892ba179ec67ccc9dbfe93d925551d6977ba15..4950e8269e9cf0723d717bd1734518d104c0c9f2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -17,28 +17,45 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { +using ::testing::_; namespace op = xla::testing::opcode_matchers; -class WhileLoopSimplifierTest : public HloVerifiedTestBase { +// Returns the first kWhile instruction within m's entry computation. +HloInstruction* FindFirstWhile(HloModule* m) { + const auto& instrs = m->entry_computation()->instructions(); + return *absl::c_find_if(instrs, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); +} + +class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. 
- void MakeModuleWithSimpleLoop(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoop(int num_iters); // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to // the loop-condition through an element of a tuple which is the // loop-condition parameter. - void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string_template = R"( HloModule SimpleLoop SimpleLoop.body { @@ -67,10 +84,11 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( int num_iters) { string hlo_string_template = R"( HloModule SimpleLoopWithIndirectLoopBound @@ -104,60 +122,55 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + 
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationTupleElementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationTupleELementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/2); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/2); 
+ EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithControlDependencySimplifiedDependencyPreserved) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -166,9 +179,8 @@ TEST_F(WhileLoopSimplifierTest, // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. 
TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -179,13 +191,12 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -194,7 +205,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and @@ -202,16 +213,15 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. 
TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A non-tuple shaped loop shouldn't be simplified or crash the compiler. @@ -236,8 +246,8 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A while loop that does nothing else besides swapping tuple elements @@ -268,8 +278,8 @@ TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each @@ -297,8 +307,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Nothing to simplify in a while loop 
whose tuple has 0 elements. @@ -320,8 +330,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't @@ -348,8 +358,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // This while loop has three tuple elements. Element 0 is unused and should be @@ -390,16 +400,15 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { } )"; - ParseAndVerifyModule(hlo_string); - HloModule* the_module = &module(); - EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. 
HloInstruction* new_while_op = - *std::find_if(the_module->entry_computation()->instructions().begin(), - the_module->entry_computation()->instructions().end(), + *std::find_if(m->entry_computation()->instructions().begin(), + m->entry_computation()->instructions().end(), [&](const HloInstruction* instr) { return (instr->opcode() == HloOpcode::kWhile && instr->name() != "while"); @@ -440,8 +449,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, @@ -473,8 +482,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { @@ -505,8 +514,233 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { + const string hlo_string = R"( + HloModule Test + Body { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ta = (s32[1]) get-tuple-element(param), index=0 + a = s32[1] get-tuple-element(ta), index=0 + a.1 = s32[1] add(a, a) + tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1 + ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + } + Cond { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ROOT cond = pred[] constant(true) + 
} + ENTRY Loop { + a = s32[1] constant({0}) + b = s32[2] constant({0,1}) + c = s32[3] constant({0,1,2}) + d = s32[4] constant({0,1,2,3}) + ta = (s32[1]) tuple(a) + td = (s32[4]) tuple(d) + tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td) + init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape flat_tuple = + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") + .ValueOrDie(); + SCOPED_TRACE(m->ToString()); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") + .ValueOrDie())); +} + +// Edge-case: All elements of the loop carry are constants which can be removed, +// leaving us with a nullary loop. This is a special case, we just replace the +// loop with its init. 
+TEST_F(WhileLoopSimplifierTest, OnlyConstantsInLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1]) parameter(0) + a = s32[1] constant({0}) + ROOT tuple = (s32[1]) tuple(a) + } + Cond { + param = (s32[1]) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + init = (s32[1]) tuple(a) + ROOT while = (s32[1]) while(init), condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1], s32[2], s32[3]) parameter(0) + a = s32[1] get-tuple-element(param), index=0 + a.1 = s32[1] add(a, a) + b = s32[2] constant({1,1}) + c = s32[3] constant({10,10,10}) + ROOT tuple = (s32[1], s32[2], s32[3]) tuple(a.1, b, c) + } + Cond { + param = (s32[1], s32[2], s32[3]) parameter(0) + /* Use each tuple element. The verifier will then ensure that if any of + * these get modified, they're replaced with values of the correct shape. */ + a = s32[1] get-tuple-element(param), index=0 + b = s32[2] get-tuple-element(param), index=1 + c = s32[3] get-tuple-element(param), index=2 + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + /* Only `b` should be simplified away. `a` is not a constant within the + * loop, and `c`'s value changes depending on whether we run 0 or 1 + * iterations of the loop. 
*/ + a = s32[1] constant({0}) + b = s32[2] constant({1,1}) + c = s32[3] constant({2,2,2}) + init = (s32[1], s32[2], s32[3]) tuple(a,b,c) + ROOT while = (s32[1], s32[2], s32[3]) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + // Run the tuple simplifier to make the resulting HLO a bit easier to check. + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(_, op::Constant(), _)); +} + +const char* const kSimpleMergeInductionVariablesModule = R"( + HloModule Test + Body { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + + a = TYPE[] get-tuple-element(param), index=0 + one = TYPE[] constant(1) + a1 = TYPE[] add(a, one) + + b = TYPE[] get-tuple-element(param), index=1 + negone = TYPE[] constant(-1) + b1 = TYPE[] add(b, negone) + + c = TYPE[] add(a, b) + + ROOT tuple = (TYPE[], TYPE[], TYPE[]) tuple(a1,b1,c) + } + Cond { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + a = TYPE[] get-tuple-element(param), index=0 + b = 
TYPE[] get-tuple-element(param), index=1 + sum = TYPE[] power(a, b) + ten = TYPE[] constant(10) + ROOT cond = pred[] less-than(sum, ten) + } + ENTRY Loop { + a = TYPE[] constant(10) + b = TYPE[] constant(100) + c = TYPE[] constant(0) + init = (TYPE[], TYPE[], TYPE[]) tuple(a,b,c) + while = (TYPE[], TYPE[], TYPE[]) while(init), condition=Cond, body=Body + + a1 = TYPE[] get-tuple-element(while), index=0 + b1 = TYPE[] get-tuple-element(while), index=1 + ROOT sum = TYPE[] add(a1, b1) + })"; + +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s32"}}); + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find, and run the tuple simplifier to make the resulting HLO + // easier to check. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + // We should have added a new loop counter for s32[] to the end of the tuple. 
+ SCOPED_TRACE(m->ToString()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + + EXPECT_THAT(new_while->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1), op::Add(), + op::Add(op::GetTupleElement(op::Parameter(), 3), + op::Constant()))); + EXPECT_THAT(new_while->while_condition()->root_instruction(), + op::Lt(op::Power(op::Add(), op::Add()), op::Constant())); +} + +// We shouldn't merge S16 induction variables; we can't create constants of this +// type because S16 literals are not implemented. +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s16"}}); + EXPECT_FALSE( + WhileLoopSimplifier() + .Run(ParseAndReturnVerifiedModule(hlo_string).ValueOrDie().get()) + .ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1f583ca44b7d20d56f27560f4a97a38c3fcc3026..039ccda7322f5efda6a827efbeda1225c3596cc0 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_util.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -270,4 +272,17 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { return result; } +/*static*/ absl::flat_hash_map> +WhileUtil::GetGTEsMapForWhileConditional( + const HloComputation& while_conditional) { + absl::flat_hash_map> result; + for (HloInstruction* user : + while_conditional.parameter_instruction(0)->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + result[user->tuple_index()].push_back(user); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 524dcec5f12689027ef76b8ae180bcbcc7cff601..cba41ccd8b184ba3d867bc170724aee71e777788 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -85,6 +87,13 @@ class WhileUtil { // Assumes `while_body` is the body computation of the while loop in question. static std::vector GetInvariantGTEsForWhileBody( const HloComputation& while_body); + + // Returns a map of index to GetTupleElement instructions in + // `while_conditional` that access elements in the parameter tuple. Assumes + // `while_conditional` is the conditional computation of the while loop in + // question. 
+ static absl::flat_hash_map> + GetGTEsMapForWhileConditional(const HloComputation& while_conditional); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index b9ef18892d7aa859f6b0b505db4c004e4f5c5066..a546a6d39cc55d1f327b8449c7d26cd4c95dbf98 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -45,7 +45,8 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - auto module = CreateNewModule("zero_sized_elimination_test_module"); + auto module = + CreateNewUnverifiedModule("zero_sized_elimination_test_module"); module->AddEntryComputation(builder_.Build()); return ZeroSizedHloElimination{}.Run(module.get()); } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 14c35e7b84f07bebac33a9753ac26a8ee1418f1e..33edbd1b20d01bf132f2a152625d5f49a45f26f9 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,8 +47,11 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Compile(const CompileRequest* arg, + CompileResponse* result) = 0; + + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index df610102b4c7fa08c0b7030124939009130f89f4..7bf97729165bef98fabc29040e02203eee68a53c 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ 
b/tensorflow/compiler/xla/shape_tree.h @@ -667,12 +667,11 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, template bool ShapeTree::operator==(const ShapeTree& other) const { bool equal = true; - ForEachElement( - [this, &other, &equal](const ShapeIndex& index, const T& data) { - if (data != other.element(index)) { - equal = false; - } - }); + ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) { + if (data != other.element(index)) { + equal = false; + } + }); return equal; } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c8ff55e7845785d9292516b823fb591cc28cbfad..2b6c484bc4f205be0180403eeac2dd391029b110 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -52,10 +52,10 @@ class ShapeTreeTest : public ::testing::Test { TEST_F(ShapeTreeTest, DefaultConstructor) { ShapeTree int_tree; - EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(int_tree.shape())); ShapeTree bool_tree; - EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(bool_tree.shape())); } void ShapeTreeTest::TestShapeConstructor(const Shape& shape, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 17120e610cb26dda41fffd28fdb2b9e8bdffb973..7d011bfc658a1f0fc27d93027be355f49966bd62 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -74,6 +74,11 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { return out; } +bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { + return size() >= prefix.size() && + indices_.subspan(0, prefix.size()) == prefix.indices_; +} + namespace { // Returns whether the given primitive type corresponds to an array shape. 
@@ -367,10 +372,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsTuple(shape) && TupleElementCount(shape) == 0; } -/* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsEmptyTuple(shape); -} - /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { CHECK(IsTuple(shape)) << HumanString(shape); return shape.tuple_shapes_size(); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 191ab04759f2d0ae87d988cba0d303f1ab696432..7f72e57d008a71c7aa01262610dfb745641976b7 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -100,6 +101,11 @@ class ShapeIndex { string ToString() const; + template + friend H AbslHashValue(H h, const ShapeIndex& index) { + return H::combine(std::move(h), index.indices_); + } + private: container_type indices_; }; @@ -147,6 +153,9 @@ class ShapeIndexView { string ToString() const; + // Returns true if this shape index starts with 'prefix'. + bool StartsWith(ShapeIndexView prefix) const; + private: absl::Span indices_; }; @@ -465,9 +474,6 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is the nil shape (an empty tuple). - static bool IsNil(const Shape& shape); - // Returns the number of elements in the given tuple shape. 
// Precondition: IsTuple(shape) static int64 TupleElementCount(const Shape& shape); @@ -751,10 +757,18 @@ class ShapeUtil { pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); } + tensorflow::mutex mu; + Status status; // Guarded by mu + while (n < rank) { if (pool != absl::nullopt) { - pool->Schedule( - [indexes, &visitor_function] { visitor_function(indexes); }); + pool->Schedule([indexes, &visitor_function, &mu, &status] { + StatusOr result = visitor_function(indexes); + if (!result.ok()) { + tensorflow::mutex_lock lock(mu); + status = status.ok() ? result.status() : status; + } + }); } else { TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); if (!should_continue) { @@ -772,7 +786,9 @@ class ShapeUtil { } } - return Status::OK(); + // Waits for the scheduled work to complete. + pool.reset(); + return status; } TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0c647369a37e70f93abe1732963d2ddc7730c214..11b493323cb4a44909bc535d1bbc04fda7506728 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -376,12 +376,12 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { } TEST(ShapeUtilTest, NilShape) { - EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); } diff --git 
a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d395c9a4ceecfbd38076ac51f5a18da2ef098abb..20493a354cf486051ec3f47146e48c01a92af83b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,7 +44,7 @@ cc_library( testonly = True, srcs = ["xla_internal_test_main.cc"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -117,12 +117,12 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", @@ -135,50 +135,13 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) -cc_library( - name = "hlo_verified_test_base", - testonly = True, - srcs = ["hlo_verified_test_base.cc"], - hdrs = ["hlo_verified_test_base.h"], - deps = [ - ":hlo_test_base", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "hlo_verified_test_base_test", - srcs = 
["hlo_verified_test_base_test.cc"], - deps = [ - ":hlo_test_base", - ":hlo_verified_test_base", - ":test_macros_cpu", - ":test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -335,6 +298,31 @@ xla_test( ], ) +xla_test( + name = "conv_depthwise_test", + timeout = "long", + srcs = ["conv_depthwise_test.cc"], + blacklisted_backends = [ + # disabled because of a break b/119590850. + "cpu", + "gpu", + ], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], @@ -868,7 +856,8 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], - shard_count = 25, + shard_count = 40, + tags = ["optonly"], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 
c131bfd6a6e6d8f3a929145fa06247c3addc5550..0615f9425c1289d666641f4d581946b44b4895ce 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2478,8 +2478,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { - { 00 }, - { 01 } + { 0, 0 }, + { 0, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2492,8 +2492,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 1100 }, - { 0001 } + { 1, 1, 0, 0 }, + { 0, 0, 0, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2506,8 +2506,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 0100 }, - { 0000 } + { 0, 1, 0, 0 }, + { 0, 0, 0, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2520,8 +2520,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 1011 }, - { 1111 } + { 1, 0, 1, 1 }, + { 1, 1, 1, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2534,8 +2534,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 0011 }, - { 1110 } + { 0, 0, 1, 1 }, + { 1, 1, 1, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2744,12 +2744,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); const string expected = R"(pred[2,3,2] { -{ { 01 }, - { 00 }, - { 00 } }, -{ { 01 }, - { 10 }, - { 01 } } +{ + { 0, 1 }, + { 0, 0 }, + { 0, 0 } +}, +{ + { 0, 1 }, + { 1, 0 }, + { 0, 1 } +} })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); 
} diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 9966e4606ef7f104487182e0240e64e4c9e4d834..9930bfc95c297093584d427397cac042c296050f 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { ShapeUtil::MakeShape(F32, {}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); LOG(INFO) << hlo_module->ToString(); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9811a015e91d866d6f4de6ebb6dac536ed6c7e06..4f5b525a34252db9e967a55af0d1bf39a2dd830e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -492,6 +492,32 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } +XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { + XlaBuilder builder(TestName()); + auto a_literal = LiteralUtil::CreateR1({256.0}); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b = ConcatInDim(&builder, {a, a}, 0); + auto c = ConcatInDim(&builder, {b, b}, 0); + auto d = ConcatInDim(&builder, {c, c}, 0); + auto e = ConcatInDim(&builder, {d, d}, 0); + auto f = ConcatInDim(&builder, {e, e}, 0); + auto g = ConcatInDim(&builder, {f, f}, 0); + auto h = ConcatInDim(&builder, {g, g}, 0); + auto i = ConcatInDim(&builder, {h, h}, 0); + auto j = 
ConcatInDim(&builder, {i, i}, 0); + auto k = ConcatInDim(&builder, {j, j}, 0); + auto l = ConcatInDim(&builder, {k, k}, 0); + auto m = ConcatInDim(&builder, {l, l}, 0); + auto n = ConcatInDim(&builder, {m, m}, 0); + auto o = ConcatInDim(&builder, {n, n}, 0); + auto p = ConcatInDim(&builder, {o, o}, 0); + auto q = ConcatInDim(&builder, {p, p}, 0); + ConcatInDim(&builder, {q, q}, 0); + std::vector expected(131072, 256.0); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, expected, {a_data.get()}); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..60ce576ceb20b89b59e72d821e63b0ccdee51b0b --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct DepthwiseConvolution2DSpec { + int64 output_feature, window, stride, pad, lhs_dilate; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class DepthwiseConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {128, 1, 2, 144}, + {256, 1, 2, 64}, {64, 14, 12, 172}, {16, 9, 4, 16}}; + + for (auto option : config_options) { + int64 feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + + std::vector kernel_layout = {3, 2, 1, 0}; + DepthwiseConvolution2DSpec config; + config.output_feature = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = 
{kernel_size, kernel_size, 1, feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, feature}; + } else if (feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = feature / 32; + config.output_dims = {batch, feature / 32, + activation_size - kernel_size + 1, feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. + if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. 
+ if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv, is_scheduled=true + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.output_feature); + + } else if (spec.stride == -1) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv, is_scheduled=true + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + 
absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.output_feature); + } else { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv, is_scheduled=true + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature); + } +} + +XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { + const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompareNoHloPasses( + hlo_text, ErrorSpec{0.01, 0.01}, [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest, + 
::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + DepthwiseConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 3aebf784664dac14ba2ea45c5a229b7b2e4fc39d..b52d30fd6624c26ad62bd0c5f6a6d74175e4539f 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -597,7 +597,692 @@ TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { } template -class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + 
+TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + + auto 
input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes + : public 
ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + 
client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + 
auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + 
ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + 
std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class 
Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto 
input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 1024}; + std::vector filter_dims = {3, 3, 1, 1024}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1024); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(4096, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); @@ -656,8 +1341,8 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { } }; 
-TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) { this->RunTest(); } @@ -951,6 +1636,18 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF32ForwardReversed)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f32[3,56,56,16] parameter(0) + %arg1 = f32[3,3,3,32] parameter(1) + ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1407e68d9a336b6bb1c960711015430f872aa912..3622f2c1e84639baed13059b21b20609d1347da6 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -45,7 +45,7 @@ class CopyOpTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -98,7 +98,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {&literal}); @@ -119,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto computation = 
builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, @@ -143,7 +143,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -175,7 +175,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -209,7 +209,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); Literal result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 001490c6a8c568656437465054ee4db40d0d8dee..738b6442354b01364278e3e3c713aa2cdb5cf47d 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -70,7 +70,7 @@ class CustomCallTest : public HloTestBase { }; XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto 
constant = builder.AddInstruction( @@ -85,7 +85,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); Array2D array(2, 2); @@ -106,7 +106,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( @@ -130,7 +130,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = @@ -155,7 +155,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 4d4b676a538947c8dd92a7e34db72e45766cae2c..d1fddf9d6b494a822610e41307fa103dc90bdef3 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -81,7 +81,7 @@ class FusionTest : public HloTestBase { } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -183,7 +183,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -231,7 +231,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -266,7 +266,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); // Build simple fusion computation: y = x^2 (elementwise). 
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto two = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); @@ -290,7 +290,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -314,7 +314,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -329,7 +329,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -344,7 +344,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -359,7 +359,7 @@ 
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -374,7 +374,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -389,7 +389,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -404,7 +404,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( @@ -419,7 +419,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = 
builder.AddInstruction(HloInstruction::CreateTranspose( @@ -434,7 +434,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -449,7 +449,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -465,7 +465,7 @@ XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -483,7 +483,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -501,7 +501,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( @@ -519,7 +519,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( @@ -541,7 +541,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( @@ -559,7 +559,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -587,7 +587,7 @@ std::unique_ptr MakeReduceTestComputation() { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -607,7 +607,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( @@ -630,7 +630,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( @@ -682,7 +682,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // into a fusion, it should remain shared, rather than being duplicated // within the fusion. XLA_TEST_F(FusionTest, SharedConstant) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 7ab2ecda58666acd7e9b8587d200a902b75822f3..989a7c705a8254f99e5cc0e97dfde5942f146964 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -85,6 +85,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. 
+ return Status::OK(); + } + return verifier_.Run(this).status(); +} + +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; + LOG(ERROR) << "Contents of bad module:"; + XLA_LOG_LINES(tensorflow::ERROR, ToString()); + } +} + HloTestBase::HloTestBase(bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, std::function @@ -100,17 +119,40 @@ HloTestBase::HloTestBase(se::Platform* test_platform, bool allow_mixed_precision_in_hlo_verifier, std::function instruction_can_change_layout_func) - : test_runner_(test_platform), reference_runner_(reference_platform) { + : test_runner_(test_platform), + reference_runner_(reference_platform), + verifier_layout_sensitive_(verifier_layout_sensitive), + allow_mixed_precision_in_hlo_verifier_( + allow_mixed_precision_in_hlo_verifier) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, instruction_can_change_layout_func); } -std::unique_ptr HloTestBase::CreateNewModule(const string& name) { +std::unique_ptr HloTestBase::CreateNewUnverifiedModule( + const string& name) { return absl::make_unique(name, GetModuleConfigForTest()); } +std::unique_ptr HloTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + +StatusOr> +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique( + TestName(), config, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} 
+ /* static */ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, HloModule* module) { @@ -135,7 +177,7 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { } DebugOptions HloTestBase::GetDebugOptionsForTest() { - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 217428befa474448cf2dcbae2eb6cb5b0e61d44c..1d1e7f437296a7493ef7da07039fcf6d273f35bc 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -38,6 +39,31 @@ limitations under the License. namespace xla { +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. 
+ void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + // A base class for tests which build and/or run HLO code. The class includes // support for running an HLO module on two platforms and compare the results. // This is a lower level of abstraction than using the client interface and @@ -72,7 +98,22 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - std::unique_ptr CreateNewModule(const string& name = TestName()); + // + // This returns a vanilla HloModule that doesn't run the HLO verifier on + // destruction. + ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") + std::unique_ptr CreateNewUnverifiedModule( + const string& name = TestName()); + + // Like CreateNewUnverifiedModule, except the HloModule returned here runs the + // HLO verifier on destruction. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Runs the hlo_pass with the provided module and returns the result. 
This // function also verifies that the module remains unchanged when hlo_pass @@ -247,6 +288,8 @@ class HloTestBase : public ::testing::Test { HloRunner test_runner_; HloRunner reference_runner_; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc deleted file mode 100644 index 8bd0a729b77f3ec14204952cb0062103c823883e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. 
- return Status::OK(); - } - return verifier_.Run(this).status(); -} - -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { - Status status = Verify(); - if (!status.ok()) { - ADD_FAILURE() << "HloVerifier failed on module " << name() - << (message.empty() ? "" : absl::StrCat(" (", message, ")")) - << ": " << status; - } -} - -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), - verifier_layout_sensitive_(layout_sensitive), - allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} - -HloModule& HloVerifiedTestBase::module() { - if (!module_) { - module_ = CreateNewVerifiedModule(TestName()); - } - return *module_; -} - -HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(CreateNewVerifiedModule(name)); - return modules_.back().get(); -} - -void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - CHECK(!module_) << "Called ParseModule when test already has a module."; - module_ = CreateNewVerifiedModule(TestName()); - TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); - module_->VerifyOrAddFailure("after parsing"); -} - -StatusOr> -HloVerifiedTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) { - auto module = CreateNewVerifiedModule(TestName()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); - return std::move(module); -} - -std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( - const string& name) { - return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h 
b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h deleted file mode 100644 index 388a99bb36408665edbc20ade6c6a733d64db88d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ -#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ - -#include -#include -#include - -#include "absl/base/macros.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { - -// An HLO module derived class which verifies itself on destruction. This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. -class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) - : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Verifies the module using HloVerifier and returns the status. 
- Status Verify(); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(const string& message); - - private: - HloVerifier verifier_; -}; - -// A base class for HLO tests that stores a default VerifiedHloModule. -class HloVerifiedTestBase : public HloTestBase { - protected: - HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - - // Constructs a default shape verifier. - std::unique_ptr MakeShapeVerifier(); - - // Returns the default HloModule, lazily creating it if necessary via - // HloTestBase::CreateNewModule(). - ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule& module(); - - ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") - void ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Creates a new module for a test, and stores it in modules_ so it can be - // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent - // creation of unverified modules. - ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule* CreateNewModule(const string& name = TestName()); - - // Creates and returns a verified HLO module with the given name. - std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); - - private: - // It is confusing to store modules created by module() and CreateNewModule() - // in different fields, but it allows us to migrate tests to - // HloVerifiedTestBase more easily, so it's a win because we can verify more - // modules. See b/80488902. - // - // Lazily populated. Access via module(). - std::unique_ptr module_; - - // Populated by calls to CreateNewModule. 
- std::vector> modules_; - - bool verifier_layout_sensitive_; - bool allow_mixed_precision_in_hlo_verifier_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc deleted file mode 100644 index 5c0263e811f94c90a69a460525ffa0c65127ebb5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// This class includes unit tests which are expected to fail because invalid HLO -// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to -// include the necessary gunit parts to test this test machinery (needs the -// macro EXPECT_NONFATAL_FAILURE). 
The disabled tests can be run with the -// disabled tests enabled and failures can be manually compared against -// expectations. -class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; - -XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { - // Test shouldn't fail if no module is created at all. -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { - // Use module() to lazily create an empty module, build it up, and verify no - // failures. - HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { - // Use module() to lazily create an empty module and build up an invalid - // module. - HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); - - *hlo_module.entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { - // Call CreateNewModule and build up a valid module. 
- HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { - // Call CreateNewModule and build up a invalid module. - HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); - - *module->entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndVerifyModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - ParseAndVerifyModule(hlo_string); - EXPECT_EQ(module().entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - EXPECT_EQ(module->entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] 
parameter(1) - ROOT add = f32[] add(x,y) -} - -RANDOM GARBAGE -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleBad - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[1234] add(x,y) -} -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index c622b295094e53e63d0ed692d428bc97724c787c..a78ccacec114858740bf1b9c04e9b688bca5818d 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -68,7 +68,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); compiler->SetPreOptimizationHook(pre_opt_hook); @@ -90,7 +90,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - std::unique_ptr hlo_module = CreateNewModule(); + std::unique_ptr hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto module_group = absl::make_unique("test_module_group"); @@ -124,9 +124,9 @@ class LLVMCompilerTest : public ::testing::Test { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - static std::unique_ptr CreateNewModule() { + static std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + 
config.set_debug_options(GetDebugOptionsFromFlags()); return absl::make_unique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index ca7637a0cfa5d837dfd86aadafd1e5cc19ffc22e..3f5135438fc59bea98527b1be30ee49339edd455 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -62,7 +62,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); const Shape elem_shape2 = @@ -122,7 +122,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest1D(bool manual_fusion, int size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape_F32 = ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 58539e6b061b0cec1cc660b52e78894e5deeea56..774eb8d2a85914c52597144e70838ee117ee1134 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -87,8 +87,8 @@ TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); ConstantR2(&builder, {{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { - { 011 }, - { 100 } + { 0, 1, 1 }, + { 1, 0, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 7e1f4aa0eb4801876d9bdbac6a4d7f1d09f81ba8..32de0fdf78f9c442e17c55e1b951e39122dac5ef 100644 --- 
a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -129,6 +129,42 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_InversePermutation) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + permutation = s32[3,4] parameter(0) + reshape = s32[3,4,1] reshape(permutation) + operand = s32[3,4] iota(), iota_dimension=1 + updates = s32[3,4,1,1] iota(), iota_dimension=1 + iota = s32[3,4,1] iota(), iota_dimension=0 + indices = s32[3,4,2] concatenate(iota, reshape), dimensions={2} + ROOT scatter = s32[3,4] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={2,3}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=2 +} +)"; + Literal permutation = + LiteralUtil::CreateR2({{1, 3, 2, 0}, {3, 0, 2, 1}, {2, 3, 1, 0}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + auto actual = ExecuteAndTransfer(std::move(module), {&permutation}); + Literal expected = + LiteralUtil::CreateR2({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); +} + XLA_TEST_F(ScatterTest, SimpleR4) { const char* hlo_text = R"( HloModule SimpleR4 diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 2cc33ab0963afe8ba2d8e9a6972dcf0622e27c48..3fb69419e735bfd9c5054673e0687f5139a410cb 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -166,6 +166,26 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } +TEST_F(SliceTest, SliceOfReshape) { + Array2D 
values(2 * 3 * 24, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR2FromArray2D(&builder, values); + auto reshape = Reshape(original, {24, 3, 2, 7}); + Slice(reshape, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1}); + ComputeAndCompare(&builder, {}); +} + +TEST_F(SliceTest, SliceOfCollapsingReshape) { + Array4D values(2, 3, 5, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR4FromArray4D(&builder, values); + auto reshape = Reshape(original, {2 * 3 * 5, 7}); + Slice(reshape, {0, 0}, {4, 7}, {1, 1}); + ComputeAndCompare(&builder, {}); +} + XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { Array4D values(2, 4, 6, 8); values.FillRandom(3.14f); @@ -253,7 +273,6 @@ XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } - // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b34fd0f2e873214c509533f29553af914ddc984d..a2b7c26331b3cc89ed0413efe8eb31c2b9e37038 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -28,7 +28,7 @@ namespace { class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateToken()); @@ -38,8 +38,22 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } +XLA_TEST_F(TokenHloTest, TokenInTuple) { + std::unique_ptr module = CreateNewUnverifiedModule(); + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + 
builder.AddInstruction(HloInstruction::CreateTuple({token})); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + Literal token_literal = LiteralUtil::CreateToken(); + EXPECT_TRUE( + LiteralTestUtil::Equal(result, LiteralUtil::MakeTuple({&token_literal}))); +} + XLA_TEST_F(TokenHloTest, TokenTree) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); @@ -54,7 +68,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) { } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); @@ -75,7 +89,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateParameter( 0, @@ -95,7 +109,7 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 
376559500efad6a756f8a0f60f0a522db047c0e5..ca036f1ae0d5e31a3f83d9d31c80e070c2a666df 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -91,8 +91,8 @@ Status ParseOneProfileOutputLine( string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; - string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; - string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?"; + string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?"; // The underlined part is what we're trying to match with match_opcode: // @@ -307,6 +307,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape, matrix_shape); + SCOPED_TRACE(profile_output); std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); @@ -318,14 +319,13 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = std::find_if( - while_body_profile_start, profile_output_lines.end(), - [](absl::string_view s) { - return absl::StartsWith(s, "********** microseconds report **********"); - }); + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds "); + }); - // We emit a blank line before the "********** microseconds report **********" - // line. + // We emit a blank line before the "microseconds report" line. 
while_body_profile_end--; ASSERT_NE(while_body_profile_end, profile_output_lines.end()); @@ -380,7 +380,7 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 15603619b62d8f45cdce97ac7d83924a78f88cf3..dca0aa52a533130372759156a3238f1a3b10ca42 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -15,14 +15,14 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { LOG(ERROR) << "\n" << usage; @@ -49,7 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { // different API than Tensorflow's. 
testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) - base::SetFlag(&FLAGS_benchmarks, pattern); + absl::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); #else tensorflow::testing::Benchmark::Run(pattern); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 3a086c66bbb37965b1ad7c83a93f0054ae723e87..8926bbed2b54fceaaf0e6e991f0e881d35731ef4 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -33,6 +33,7 @@ cc_library( name = "dumped_computation_to_graphviz_library", srcs = ["dumped_computation_to_graphviz.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -40,7 +41,6 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -78,6 +78,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -91,7 +92,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -207,13 +207,13 @@ tf_cc_binary( name = "dumped_computation_to_tf_graphdef", srcs = ["dumped_computation_to_tf_graphdef.cc"], deps = [ + 
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index c866a13de7543fc948311f94708bc6b904717b62..b623556468fb4a5d96be614b6c067d5a1df51a6f 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -54,7 +54,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = client->GetComputationStats(computation, debug_options) @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - 
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 07ef5ff656bb48519a700a1d7d6c60b655a40ed6..f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,7 +53,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); ComputationStats stats = @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 109411f99b6eb000474b0c61783c51f42d43bb6d..47be9f5adf1063463d7678579a7f394684aaf357 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -47,8 +47,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -191,8 +191,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, // Run the computation num_runs times, and return the result from the last // execution. 
- const bool xla_hlo_profile = - legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 8ce741647414a1fa75e6d706ec1e719ace7b7cc8..b015f4328a15473db862b753c907975856383a79 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -387,6 +387,19 @@ T CeilOfRatio(T dividend, T divisor) { return tensorflow::MathUtil::CeilOfRatio(dividend, divisor); } +template +std::vector ElementWiseCeilOfRatio(absl::Span dividends, + absl::Span divisors) { + std::vector ceil_of_ratios; + CHECK_EQ(dividends.size(), divisors.size()); + ceil_of_ratios.reserve(dividends.size()); + absl::c_transform(dividends, divisors, std::back_inserter(ceil_of_ratios), + [](const T dividend, const T divisor) { + return CeilOfRatio(dividend, divisor); + }); + return ceil_of_ratios; +} + // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio // then multiplying by the divisor. 
For example: RoundUpToNearest(13, 8) => 16 template diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 8ea8dbab2574ca1e24271e7c1c7762d4a6b6a8de..51c73b3d17e4c32d9a8a14d3055ab56f02922af3 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -185,6 +185,17 @@ bool HasWindowReversal(const Window& window) { return false; } +bool AllOrNoneReversed(const Window& window) { + if (window.dimensions().empty()) { + return true; + } + bool reversed = window.dimensions()[0].window_reversal(); + return std::all_of(window.dimensions().begin(), window.dimensions().end(), + [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 1fb9e855fc16f334eb0e83dfd27b307b2149628f..099d7ecdd5c732ffc8c6ff6370288a2fc4144fa2 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -56,6 +56,7 @@ bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); bool HasWindowReversal(const Window& window); +bool AllOrNoneReversed(const Window& window); // Returns true if the given logical dimension is inactive in the sense that it // has window bound 1, no striding and no padding. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 65948ef4b0c3d51805b15634e6215f192e740aaa..28df3b03f398841460189910bc3a5096dfb0d367 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -322,6 +322,34 @@ message UnregisterRequest { message UnregisterResponse { } +message CompileRequest { + // The graph to be compiled. + HloModuleProto computation = 1; + + // Options that affect how XLA compiles code to service this request. 
+ ExecutionOptions execution_options = 2; + + // The layouts of the input arguments. If not set, the default layout will be + // used. Although the real arguments are not needed in compilation, the + // layouts of the arguments can affect the compilation. + repeated Shape input_shape_with_layout = 3; +} + +message CompileResponse { + // The handle to the executable. + ExecutionHandle handle = 1; +} + +message ExecuteRequest { + ExecutionHandle handle = 1; + + // The shape and layout of the arguments must be the same as those of the + // executable's parameters. + repeated GlobalDataHandle arguments = 2; +} + +// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace +// the uses with calls to Compile and Execute. message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index b6bd919e2b26a109cb9dfd2a6aaba86f1732cff1..683ccc40f162ead3a248aee83d9abf3086a1ac93 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -332,11 +332,13 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. 
repeated LiteralProto tuple_literals = 10; - // The F16s and BF16s are encoded in little endian byte order + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 16 + // Next = 18 } message WindowDimension { diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 9e3d2454d16730c1d1f93cb384db88544380f77e..67f475846e5f16060c1080759b0acb4216c4e72b 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -12,6 +12,7 @@ cc_library( hdrs = ["xrt_state_ops.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -21,7 +22,6 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 5678f0905ff5b8956e0811026e7450acba8815e9..6ab77fbaaf0cbe23503ebc71775f52af01e41a74 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -6,6 +6,24 @@ import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; +message DeviceAssignment { + message ComputationDevice { + message DeviceMeshCoordinates { + // The mesh coordinates for the device. Usually (X, Y, Core), in the order + // in which they are returned in the TopologyProto. 
+ // X = value(0) + // Y = value(1) + // Core = value(2) + repeated int32 value = 1; + } + // As many replicas as there are in the replicated computation. + repeated DeviceMeshCoordinates replica_devices = 1; + } + // As many ComputationDevice entries as there are computations (number + // of cores per replica). + repeated ComputationDevice computation_devices = 1; +} + // Options for an XLA compilation. message XLAComputationConfig { // The number of replicas the computation will be run on. If this is @@ -23,6 +41,11 @@ message XLAComputationConfig { // computation. per_core_args_and_result_shapes is optional for a // single-core computation. repeated xla.ProgramShape per_core_program_shape = 5; + // Describes how replicated computation instances should be assigned to + // devices. There are num_cores_per_replica computations, and each one will be + // sent to and executed on the set of replica device numbers described in the + // DeviceAssignment proto. + DeviceAssignment device_assignment = 6; } // Options and XLA computation for a compilation. 
diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index a513aa1e7c49d64a860c740fffde156fb5bcbcf3..f6c6560c1c354ed8a36b98b1f564835eb9958e55 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -9,8 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_py_test") - py_library( name = "all_reduce_py", srcs = ["__init__.py"], @@ -29,29 +27,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nccl_ops", - ], -) - -tf_py_test( - name = "all_reduce_test", - srcs = ["python/all_reduce_test.py"], - additional_deps = [ - ":all_reduce", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:state_ops", + "//tensorflow/python/distribute:all_reduce", ], ) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 25f4b4b8d341331db79321338a88cabfe325eea5..238cdaf8a79812df3f043d9d070bbcfd443f6e1e 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -18,842 +18,5 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import math - -from tensorflow.python.framework import device as device_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import 
math_ops -from tensorflow.python.ops import nccl_ops - - -def _flatten_tensors(tensors): - """Check tensors for isomorphism and flatten. - - Args: - tensors: list of T `tf.Tensor` which must all have the same shape. - - Returns: - tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors - shape: the original shape of each element of input tensors - - Raises: - ValueError: tensors are empty or non-isomorphic or have unknown shape. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - for tensor in tensors: - shape = shape.merge_with(tensor.shape) - if not shape.is_fully_defined(): - raise ValueError("Tensors must have statically known shape.") - if len(shape) != 1: - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, [-1])) - tensors = reshaped - return tensors, shape - - -def _reshape_tensors(tensors, shape): - """Reshape tensors flattened by _flatten_tensors. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - shape: list of integers describing the desired shape. Product of - the elements must equal the length of each tensor. - - Returns: - list of T `tf.Tensor` which are the reshaped inputs. - """ - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, shape)) - return reshaped - - -def _padded_split(tensor, pieces): - """Like split for 1D tensors but pads-out case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - thin input tensor, in order. The final tensor may - be zero-padded on the end to make its size equal to those of all - of the other tensors. - - Raises: - ValueError: The input tensor is not 1D. 
- """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape.dims[0].value - with ops.colocate_with(tensor): - if tensor_len % pieces != 0: - # pad to an even length - chunk_size = 1 + tensor_len // pieces - if pieces > tensor_len: - # This is an edge case that should not come up in practice, - # i.e. a different reduction algorithm would be better, - # but we'll make it work just for completeness. - pad_len = pieces - tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - elif (pieces - 1) * chunk_size >= tensor_len: - # Another edge case of limited real interest. - pad_len = (pieces * chunk_size) % tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - else: - last_chunk_size = tensor_len - (pieces - 1) * chunk_size - pad_len = chunk_size - last_chunk_size - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - parts = array_ops.split(tensor, piece_lens) - parts[-1] = array_ops.concat( - [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - return parts, pad_len - else: - return array_ops.split(tensor, pieces), 0 - - -def _strip_padding(tensors, pad_len): - """Strip the suffix padding added by _padded_split. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - pad_len: number of elements to be stripped from the end of each tensor. - - Returns: - list of T `tf.Tensor` which are the stripped inputs. - - Raises: - ValueError: tensors must be a non-empty list of 1D tensors, and - each must be longer than pad_len. 
- """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - if len(shape) > 1: - raise ValueError("tensors must be 1D") - prefix_len = int(shape[0] - pad_len) - if prefix_len < 0: - raise ValueError("pad_len longer than tensor") - stripped = [] - for t in tensors: - with ops.colocate_with(t): - stripped.append(array_ops.slice(t, [0], [prefix_len])) - return stripped - - -def _ragged_split(tensor, pieces): - """Like split for 1D tensors but allows case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - the input tensor, in order. The final tensor may be shorter - than the others, which will all be of equal length. - - Raises: - ValueError: input tensor must be 1D. - """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape.dims[0].value - chunk_size = tensor_len // pieces - with ops.colocate_with(tensor): - if tensor_len != (pieces * chunk_size): - # last piece will be short - assert pieces > 1 - last_chunk_size = tensor_len - ((pieces - 1) * chunk_size) - assert last_chunk_size > 0 - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - return array_ops.split(tensor, piece_lens) - else: - return array_ops.split(tensor, pieces) - - -def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for each subchunk. - - In the basic ring reduction algorithm there are size(T)/num_devices - data chunks and each device process one chunk per tick, i.e. sending - one chunk and receiving one chunk. 
The idea of subchunking is that - each device processes num_subchunks smaller data regions per tick, - and the ring rank permutation is different for each subchunk index - so that a device is potentially sending to and receiving from - num_subchunks different other devices at each tick. Where multiple - independent data channels exist between devices, this strategy - supplies a method of using them in parallel. - - Args: - num_workers: number of worker tasks - num_subchunks: number of subchunks into which to divide each per-GPU chunk. - gpu_perm: an array of integers in [0, num_gpus-1] giving the default - ring order of GPUs at each worker. Other permutations will be generated - by rotating this array and splicing together per-worker instances. - - Raises: - ValueError: the number of subchunks may not exceed the number of GPUs. - - Returns: - pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - preceding device in the permutation for that subchunk. The - device index of GPU i at worker j is i + (j * num_gpus). - rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - local rank of device d in the permutation for that subchunk. 
- """ - num_gpus = len(gpu_perm) - devices = num_workers * num_gpus - if devices == 0: - return [], [] - if num_subchunks > num_gpus: - raise ValueError( - "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus)) - rotation_interval = max(1, int(num_gpus / num_subchunks)) - perms_by_s = [] - for s in range(0, num_subchunks): - full_order = [] - offset = s * rotation_interval - for w in range(0, num_workers): - default_order = [(w * num_gpus) + i for i in gpu_perm] - dev_order = default_order[offset:] + default_order[:offset] - full_order += dev_order - perms_by_s.append(full_order) - pred_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - rank_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - for s in range(0, num_subchunks): - for d in range(0, devices): - for t in range(0, devices): - if d == perms_by_s[s][t]: - rank_by_s_d[s][d] = t - pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices] - break - return (pred_by_s_d, rank_by_s_d) - - -def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, - gpu_perm, red_op, un_op=None): - """Construct a subgraph performing a ring-style all-reduce of input_tensors. - - Args: - input_tensors: a list of T `tf.Tensor` objects, which must all - have the same shape and type. - num_workers: number of worker tasks spanned by input_tensors. - num_subchunks: number of subchunks each device should process in one tick. - gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at - each worker. All workers must have the same number of - GPUs with the same rank ordering. If NVLINK is available, this should - be a ring order supported by NVLINK edges. - red_op: a binary operator for elementwise reduction. - un_op: an optional unary operator to apply to fully reduced values. - - Raises: - ValueError: empty input_tensors or they don't all have same - size. - - Returns: - a list of T `tf.Tensor` identical sum-reductions of input_tensors. 
- """ - if len(input_tensors) < 2: - raise ValueError("input_tensors must be length 2 or longer") - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - (pred_by_s_d, rank_by_s_d) = _ring_permutations( - num_workers, num_subchunks, gpu_perm) - chunks_by_dev, pad_len = _build_ring_gather( - input_tensors, devices, - num_subchunks, pred_by_s_d, rank_by_s_d, red_op) - if un_op: - chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev) - output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev) - if pad_len > 0: - output_tensors = _strip_padding(output_tensors, pad_len) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_ring_gather(input_tensors, devices, num_subchunks, - pred_by_s_d, rank_by_s_d, red_op): - """Construct a subgraph for the first (reduction) pass of ring all-reduce. - - Args: - input_tensors: a list of T `tf.Tensor` 1D input tensors of same - shape and type. - devices: array of device name strings - num_subchunks: number of subchunks each device should process in one tick. - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - red_op: a binary operator for elementwise reduction - - Raises: - ValueError: tensors must all be one dimensional. - - Returns: - list of list of T `tf.Tensor` of (partially) reduced values where - exactly num_subchunks chunks at each device are fully reduced. - """ - num_devices = len(input_tensors) - if num_devices == 0: - return [] - if num_devices == 1: - return input_tensors - shape = input_tensors[0].shape - if 1 != len(shape): - raise ValueError("input tensors must be 1D") - num_chunks = num_devices * num_subchunks - num_ticks = num_devices - 1 - # Initialize chunks_by_dev with splits of the input tensors. 
- chunks_by_dev = [] - split_pad_len = 0 - for d in range(0, num_devices): - with ops.device(devices[d]): - splits, split_pad_len = _padded_split(input_tensors[d], num_chunks) - chunks_by_dev.append(splits) - # Reduction phase - for tick in range(0, num_ticks): - # One new partial reduction for every chunk - new_partial_reductions = [None for _ in range(0, num_chunks)] - # Compute reductions with respect to last tick's values - for d in range(0, num_devices): - with ops.device(devices[d]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - new_partial_reductions[chunk_index] = red_op( - chunks_by_dev[pred_dev][chunk_index], - chunks_by_dev[d][chunk_index]) - # Update chunks_by_dev with the new values at the end of the tick. - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index] - return chunks_by_dev, split_pad_len - - -def _apply_unary_to_chunks(f, chunks_by_dev): - """Apply a unary op to each tensor in chunks_by_dev, on same device. - - Args: - f: a unary function over T `tf.Tensor`. - chunks_by_dev: list of lists of T `tf.Tensor`. - - Returns: - new list of lists of T `tf.Tensor` with the same structure as - chunks_by_dev containing the derived tensors. - """ - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append([f(t) for t in x]) - return output - - -def _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev): - """Construct subgraph for second (scatter) pass of ring all-reduce. 
- - Args: - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - chunks_by_dev: list of list of T `tf.Tensor` indexed by ints - (device, chunk) - - Raises: - ValueError: chunks_by_dev is not well-formed - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device corresponding to the outer dimension of chunks_by_dev. - """ - num_devices = len(chunks_by_dev) - num_chunks = len(chunks_by_dev[0]) - if 0 != num_chunks % num_devices: - raise ValueError( - "Expect number of chunks per device to be divisible by num_devices") - num_subchunks = int(num_chunks / num_devices) - num_ticks = num_devices - 1 - for tick in range(0, num_ticks): - passed_values = [None for _ in range(0, num_chunks)] - for d in range(0, num_devices): - with ops.colocate_with(chunks_by_dev[d][0]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - passed_values[chunk_index] = array_ops.identity( - chunks_by_dev[pred_dev][chunk_index]) - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = passed_values[chunk_index] - # Join chunks at each device. - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append(array_ops.concat(x, 0)) - return output - - -def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): - """Construct a subgraph for recursive halving-doubling all-reduce. - - The recursive halving-doubling algorithm is described in - http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf - - The concept is to arrange the participating n devices in - a linear sequence where devices exchange data pairwise - with one other device in each round. 
During the gather - phase there are lg(n) rounds where devices exchange - increasingly smaller sub-tensors with another device - at increasingly greater distances, until at the top - each device has 1/n of the fully reduced values. During the - scatter phase each device exchanges its fully reduced - sub-tensor (which doubles in length at each round) - with one other device at increasingly smaller distances - until each device has all of the fully reduced values. - - Note: this preliminary version requires that len(input_tensors) be a - power of 2. TODO(tucker): relax this restriction. Also, the - number of elements in each tensor must be divisible by 2^h where h - is the number of hops in each phase. This will also be relaxed in - the future with edge-case specific logic. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - red_op: a binary elementwise reduction Op. - un_op: an optional unary elementwise Op to apply to reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device of input_tensors. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - devices = [t.device for t in input_tensors] - input_tensors, shape = _flatten_tensors(input_tensors) - reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op) - if un_op: - reduced_shards = [un_op(t) for t in reduced_shards] - output_tensors = _build_recursive_hd_scatter(reduced_shards, devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_recursive_hd_gather(input_tensors, devices, red_op): - """Construct the gather phase of recursive halving-doubling all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - devices: a list of strings naming the devices hosting input_tensors, - which will also be used to host the (partial) reduction values. 
- red_op: a binary elementwise reduction Op. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensor shards. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - if num_devices != (2 ** num_hops): - raise ValueError("num_devices must be a power of 2") - chunks = input_tensors - for h in range(0, num_hops): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_dev = devices[d] - right_dev = devices[d + span] - left_split = array_ops.split(chunks[d], 2) - right_split = array_ops.split(chunks[d+span], 2) - with ops.device(left_dev): - new_chunks[d] = red_op(left_split[0], right_split[0]) - with ops.device(right_dev): - new_chunks[d + span] = red_op(left_split[1], right_split[1]) - chunks = new_chunks - return chunks - - -def _build_recursive_hd_scatter(input_tensors, devices): - """Construct the scatter phase of recursive halving-doublng all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` that are fully-reduced shards. - devices: a list of strings naming the devices on which the reconstituted - full tensors should be placed. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. 
- """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - assert num_devices == (2 ** num_hops), "num_devices must be a power of 2" - chunks = input_tensors - for h in reversed(range(0, num_hops)): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_idx = d - right_idx = d + span - left_dev = devices[left_idx] - right_dev = devices[right_idx] - with ops.device(left_dev): - new_chunks[left_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - with ops.device(right_dev): - new_chunks[right_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - chunks = new_chunks - return chunks - - -def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): - """Construct a subgraph for shuffle all-reduce. - - Shuffle reduce is essentially the algorithm implemented when using - parameter servers. Suppose tensor length is n, there are d devices - and g gather shards. Each device sends a n/g length sub-tensor to - each gather shard. The gather shards perform a reduction across d - fragments, then broadcast the result back to each device. The - devices then join the g fully reduced fragments they receive from - the shards. The gather shards could perform d-1 pairwise - reductions, or one d-way reduction. The first is better where - reduction Op time is low compared to transmission time, the second - better in the other case. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: an n-array elementwise reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. 
- """ - input_tensors, shape = _flatten_tensors(input_tensors) - dst_devices = [t.device for t in input_tensors] - reduced_shards = _build_shuffle_gather(input_tensors, gather_devices, - red_op, un_op) - output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None): - """Construct the gather (concentrate and reduce) phase of shuffle all-reduce. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: the binary reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced shards. - - Raises: - ValueError: inputs not well-formed. - """ - num_source_devices = len(input_tensors) - num_gather_devices = len(gather_devices) - shape = input_tensors[0].shape - if len(shape) != 1: - raise ValueError("input_tensors must be 1D") - shards_by_source = [] - for d in range(0, num_source_devices): - with ops.colocate_with(input_tensors[d]): - shards_by_source.append( - _ragged_split(input_tensors[d], num_gather_devices)) - reduced_shards = [] - for d in range(0, num_gather_devices): - with ops.device(gather_devices[d]): - values = [s[d] for s in shards_by_source] - red_shard = red_op(values) - if un_op: - red_shard = un_op(red_shard) - reduced_shards.append(red_shard) - return reduced_shards - - -def _build_shuffle_scatter(reduced_shards, dst_devices): - """Build the scatter phase of shuffle all-reduce. - - Args: - reduced_shards: list of T @(tf.Tensor} fully reduced shards - dst_devices: list of names of devices at which the fully-reduced value - should be reconstituted. - - Returns: - list of T `tf.Tensor` scattered tensors. 
- """ - num_devices = len(dst_devices) - out_tensors = [] - for d in range(0, num_devices): - with ops.device(dst_devices[d]): - out_tensors.append(array_ops.concat(reduced_shards, 0)) - return out_tensors - - -def _split_by_task(devices, values): - """Partition devices and values by common task. - - Args: - devices: list of device name strings - values: list of T `tf.tensor` of same length as devices. - - Returns: - (per_task_devices, per_task_values) where both values are - lists of lists with isomorphic structure: the outer list is - indexed by task, and the inner list has length of the number - of values belonging to that task. per_task_devices contains - the specific devices to which the values are local, and - per_task_values contains the corresponding values. - - Raises: - ValueError: devices must be same length as values. - """ - num_devices = len(devices) - if num_devices != len(values): - raise ValueError("len(devices) must equal len(values)") - per_task_devices = collections.OrderedDict() - per_task_values = collections.OrderedDict() - for d in range(num_devices): - d_spec = device_lib.DeviceSpec.from_string(devices[d]) - if not hasattr(d_spec, "task") or d_spec.task is None: - assert False, "failed to parse device %s" % devices[d] - index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task) - if index not in per_task_devices: - per_task_devices[index] = [] - per_task_values[index] = [] - per_task_devices[index].append(devices[d]) - per_task_values[index].append(values[d]) - - return (list(per_task_devices.values()), list(per_task_values.values())) - - -def build_nccl_all_reduce(input_tensors, red_op, un_op=None): - """Build a subgraph that does one full all-reduce, using NCCL. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. Must be one of - {tf.add} - un_op: optional unary elementwise Op to apply to fully-reduce values. 
- - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: red_op not supported. - """ - if red_op == math_ops.add: - output_tensors = nccl_ops.all_sum(input_tensors) - else: - raise ValueError("red_op not supported by NCCL all-reduce: ", red_op) - if un_op: - un_op_wrapped = [] - for t in output_tensors: - with ops.colocate_with(t): - un_op_wrapped.append(un_op(t)) - output_tensors = un_op_wrapped - return output_tensors - - -def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): - """Construct a subgraph for NCCL hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [None for w in range(0, num_workers)] - up_devices = up_values[:] - down_values = up_values[:] - # First stage: reduce within each worker using NCCL - for w in range(0, num_workers): - worker_values = build_nccl_all_reduce(per_worker_values[w], red_op) - # NOTE: these reductions will not run to completion unless - # every output value is used. Since we only need one, we - # need to put control dependencies on the rest. 
- with ops.control_dependencies(worker_values): - with ops.device(worker_values[0].device): - up_values[w] = array_ops.identity(worker_values[0]) - up_devices[w] = per_worker_devices[w][0] - # Second stage: Apply upper_level_f to reduce across first device at - # each worker - level_2_output = upper_level_f(up_values) - # Third stage: propagate within each worker using NCCL Broadcast - for w in range(0, num_workers): - dst_tensors = [] - with ops.device(per_worker_devices[w][0]): - broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w])) - for d in per_worker_devices[w]: - with ops.device(d): - dst_tensors.append(array_ops.identity(broadcast_src)) - down_values[w] = dst_tensors - output_tensors = [v for sublist in down_values for v in sublist] - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _reduce_non_singleton(input_tensors, red_f, un_op): - """If input_tensors has more than one element apply red_f, else apply un_op.""" - if len(input_tensors) > 1: - return red_f(input_tensors) - else: - if not un_op: - return input_tensors - output_tensors = [] - for t in input_tensors: - with ops.colocate_with(t): - output_tensors.append(un_op(t)) - return output_tensors - - -def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Ring across workers.""" - def upper_builder(y): - return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op) - def upper_level_f(x): - return _reduce_non_singleton(x, upper_builder, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Recursive-HD across workers.""" - upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_shuffle(input_tensors, 
gather_devices, nccl_red_op, - shuffle_red_op, un_op=None): - """Construct hybrid of NCCL within workers, Shuffle across workers.""" - upper_level_f = lambda x: build_shuffle_all_reduce(x, gather_devices, - shuffle_red_op, un_op) - return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f) - - -def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): - """Construct a subgraph for Shuffle hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - gather_devices: list of device names on which to host gather shards. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - # First stage, reduce across each worker using gather_devices. - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [] - if len(gather_devices) != num_workers: - raise ValueError("For shuffle hybrid, gather_devices must contain one " - "device per worker. ") - for w in range(0, num_workers): - reduced_shards = _build_shuffle_gather( - per_worker_values[w], [gather_devices[w]], red_op) - up_values.append(reduced_shards[0]) - # Second stage, apply upper_level_f. - level_2_output = upper_level_f(up_values) - # Third stage, apply shuffle scatter at each worker. 
- output_tensors = [] - for w in range(0, num_workers): - output_tensors += _build_shuffle_scatter( - [level_2_output[w]], per_worker_devices[w]) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Ring across workers.""" - def upper_builder(tensors): - return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, gather_devices, red_n_op, upper_level_f) - - -def build_shuffle_then_shuffle(input_tensors, first_gather_devices, - second_gather_devices, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Shuffle across workers.""" - def upper_builder(tensors): - return build_shuffle_all_reduce(tensors, second_gather_devices, - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, first_gather_devices, red_op, upper_level_f) +# pylint: disable=unused-import,wildcard-import +from tensorflow.python.distribute.all_reduce import * diff --git a/tensorflow/contrib/autograph/examples/benchmarks/BUILD b/tensorflow/contrib/autograph/examples/benchmarks/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6d2d70c99b4cc804f2c8bf57afdc8c11f1f73516 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark") + +py_library( + name = "benchmark_base", + srcs = [ + "benchmark_base.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "cartpole_benchmark", + size = 
"enormous", + srcs = ["cartpole_benchmark.py"], + tags = [ + "local", + "manual", + "no_oss", + "notap", + "nozapfhahn", + ], + deps = [ + ":benchmark_base", + # Note: required gym dependency may need to be added here. + ], +) + +tf_py_logged_benchmark( + name = "cartpole_logged_benchmark", + target = "//tensorflow/contrib/autograph/examples/benchmarks:cartpole_benchmark", +) diff --git a/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py new file mode 100644 index 0000000000000000000000000000000000000000..93c694849c4dc3faca71e7f9d8614649a7784f99 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common benchmarking code. + +See https://www.tensorflow.org/community/benchmarks for usage. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +import tensorflow as tf + + +class ReportingBenchmark(tf.test.Benchmark): + """Base class for a benchmark that reports general performance metrics. + + Subclasses only need to call one of the _profile methods, and optionally + report_results. 
+ """ + + def time_execution(self, name, target, iters, warm_up_iters=5): + for _ in range(warm_up_iters): + target() + + all_times = [] + for _ in range(iters): + iter_time = time.time() + target() + all_times.append(time.time() - iter_time) + + avg_time = np.average(all_times) + + extras = dict() + extras['all_times'] = all_times + + if isinstance(name, tuple): + extras['name'] = name + name = '_'.join(str(piece) for piece in name) + + self.report_benchmark( + iters=iters, wall_time=avg_time, name=name, extras=extras) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..4f553be58e94f11e45f0697558348fbbd26bfb91 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py @@ -0,0 +1,492 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A basic RL cartpole benchmark. + +The RL model uses the OpenAI Gym environment to train a simple network using +the policy gradients method. The training scales the gradients for each step +by the episode's cumulative discounted reward and averages these gradients over +a fixed number of games before applying the optimization step. 
+ +For benchmarking purposes, we replace the OpenAI Gym environment to a fake +that returns random actions and rewards and never ends the episode. This way +the benchmarks compare the same amount of computation at each step. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import eager +from tensorflow.contrib.autograph.examples.benchmarks import benchmark_base +from tensorflow.python import autograph as ag +from tensorflow.python.eager import context + +# +# AutoGraph implementation +# + + +@ag.convert() +def graph_append_discounted_rewards(destination, rewards, discount_rate): + """Discounts episode rewards and appends them to destination.""" + ag.set_element_type(rewards, tf.float32) + + cdr = 0.0 + reverse_discounted = [] + ag.set_element_type(reverse_discounted, tf.float32) + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + cdr.set_shape(()) + reverse_discounted.append(cdr) + + retval = destination + # Note: AutoGraph doesn't yet support reversed() so we use a loop instead. + for i in range(len(reverse_discounted) - 1, -1, -1): + retval.append(reverse_discounted[i]) + + return retval + + +class GraphPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(GraphPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. 
+ + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. + """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + # TODO(mdan): Move this method out of the class. + @ag.convert() + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + var_list = tf.trainable_variables() + grad_list = [ + tf.TensorArray(tf.float32, 0, dynamic_size=True) for _ in var_list + ] + + step_counts = [] + discounted_rewards = [] + ag.set_element_type(discounted_rewards, tf.float32) + ag.set_element_type(step_counts, tf.int32) + + # Note: we use a shared object, cart_pole_env here. Because calls to the + # object's method are made through py_func, TensorFlow cannot detect its + # data dependencies. Hence we must manually synchronize access to it + # and ensure the control dependencies are set in such a way that + # calls to reset(), take_one_step, etc. are made in the correct order. 
+ sync_counter = tf.constant(0) + + for _ in tf.range(num_games): + with tf.control_dependencies([sync_counter]): + obs = cart_pole_env.reset() + with tf.control_dependencies([obs]): + sync_counter += 1 + + game_rewards = [] + ag.set_element_type(game_rewards, tf.float32) + + for step in tf.range(max_steps_per_game): + logits, actions = self(obs) # pylint:disable=not-callable + logits = tf.reshape(logits, ()) + actions = tf.reshape(actions, ()) + + labels = 1.0 - tf.cast(actions, tf.float32) + loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + grads = tf.gradients(loss, var_list) + + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + with tf.control_dependencies([sync_counter]): + obs, reward, done = cart_pole_env.step(actions) + with tf.control_dependencies([obs]): + sync_counter += 1 + obs = tf.reshape(obs, (1, 4)) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = graph_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = ag.stack(discounted_rewards) + discounted_rewards.set_shape((None,)) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = ag.stack(grad_list[i]) + + # This block just adjusts the shapes to match for multiplication. 
+ r = normalized_rewards + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return ag.stack(step_counts) + + +@ag.convert() +def graph_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + i = tf.constant(0) + mean_steps_per_iteration = [] + ag.set_element_type(mean_steps_per_iteration, tf.int32) + + while i < iterations: + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + i += 1 + + return ag.stack(mean_steps_per_iteration) + + +class GraphGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + obs = ag.utils.wrap_py_func(self.env.reset, tf.float64, ()) + obs = tf.reshape(obs, (1, 4)) + obs = tf.cast(obs, tf.float32) + return obs + + def step(self, actions): + + def take_one_step(actions): + obs, reward, done, _ = self.env.step(actions) + obs = obs.astype(np.float32) + reward = np.float32(reward) + return obs, reward, done + + return ag.utils.wrap_py_func(take_one_step, + (tf.float32, tf.float32, tf.bool), (actions,)) + + +class GraphRandomCartpoleEnv(object): + """An environment that returns random actions and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. 
+ """ + + def reset(self): + return tf.random.normal((1, 4)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = tf.random.normal((1, 4)) + fixed_reward = tf.constant(0.001) + done = tf.constant(False) + return random_obs, fixed_reward, done + + +# +# Eager implementation +# + + +def eager_append_discounted_rewards(discounted_rewards, rewards, discount_rate): + cdr = 0.0 + reverse_discounted = [] + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + reverse_discounted.append(cdr) + + discounted_rewards.extend(reversed(reverse_discounted)) + return discounted_rewards + + +class EagerPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(EagerPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. + + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. 
+ """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + self._grad_fn = eager.implicit_gradients( + self._get_cross_entropy_and_save_actions) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + def _get_cross_entropy_and_save_actions(self, inputs): + logits, actions = self(inputs) # pylint:disable=not-callable + self._current_actions = actions + labels = 1.0 - tf.cast(actions, tf.float32) + return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) + + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + grad_list = None + + step_counts = [] + discounted_rewards = [] + + for _ in range(num_games): + obs = cart_pole_env.reset() + + game_rewards = [] + + for step in range(max_steps_per_game): + grads_and_vars = self._grad_fn(tf.constant([obs], dtype=tf.float32)) + grads, var_list = zip(*grads_and_vars) + actions = self._current_actions.numpy()[0][0] + + if grad_list is None: + grad_list = [[g] for g in grads] + else: + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + obs, reward, done = cart_pole_env.step(actions) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = eager_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = tf.stack(discounted_rewards) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = tf.stack(grad_list[i]) + + r = normalized_rewards + while r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return tf.stack(step_counts) + + +def 
eager_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + mean_steps_per_iteration = [] + + for _ in range(iterations): + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + + return mean_steps_per_iteration + + +class EagerGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + return self.env.reset() + + def step(self, actions): + obs, reward, done, _ = self.env.step(actions) + return obs, reward, done + + +class EagerRandomCartpoleEnv(object): + """An environment that returns random actions and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. + """ + + def reset(self): + return np.random.normal(size=(4,)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = np.random.normal(size=(4,)) + fixed_reward = 0.001 + done = False + return random_obs, fixed_reward, done + + +def graph_demo_training(): + """Not used in the benchmark. 
Used to confirm a functional model.""" + with tf.Graph().as_default(): + tf.set_random_seed(0) + + network = GraphPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = GraphGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + train_ops = graph_train_model(network, env, opt, iterations=5) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + steps_per_iteration = sess.run(train_ops) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +def eager_demo_training(): + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = EagerGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + steps_per_iteration = eager_train_model(network, env, opt, iterations=5) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +class RLCartPoleBenchmark(benchmark_base.ReportingBenchmark): + """Actual benchmark. + + Trains the RL agent a fixed number of times, on random environments that + result in constant number of steps. 
+ """ + + def benchmark_cartpole(self): + + def train_session(sess, ops): + return lambda: sess.run(ops) + + def train_eager(network, env, opt): + return lambda: eager_train_model(network, env, opt, iterations=10) + + for model_size in (10, 100, 1000): + with tf.Graph().as_default(): + network = GraphPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = GraphRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + train_ops = graph_train_model(network, env, opt, iterations=10) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + + self.time_execution(('cartpole', 'autograph', model_size), + train_session(sess, train_ops), 20) + + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = EagerRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + self.time_execution(('cartpole', 'eager', model_size), + train_eager(network, env, opt), 20) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 55faad983f2bcf2f3fa633669bd371608e2e925b..3e4d0dc1cec76b068c1c846eb476eec615e4f613 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,8 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import function +from tensorflow.python.eager import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -101,12 +102,15 @@ def batch_function(num_batch_threads, def decorator(fn): # pylint: disable=missing-docstring def decorated(*args): # pylint: disable=missing-docstring - types = [arg.dtype 
for arg in args] - @function.Defun(*types) + @function.defun() def computation(*computation_args): return fn(*computation_args) + computation = computation.get_concrete_function( + *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i)) + for i, x in enumerate(args)]) + with ops.name_scope("batch") as name: for a in args: if not isinstance(a, ops.Tensor): @@ -123,7 +127,7 @@ def batch_function(num_batch_threads, f=computation, in_tensors=list(args), captured_tensors=computation.captured_inputs, - Tout=[o.type for o in computation.definition.signature.output_arg]) + Tout=[o.dtype for o in computation.outputs]) return decorated diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 01ee8703a93836d607ee9b765c51c79fe3bb974f..9109b9c1c91cefa4c52bad49de23336a6e05e1ef 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -219,6 +219,7 @@ class BatchOpsTest(test.TestCase): @batch_ops.batch_function(1, 10, 100000) def computation(in_t): + self.assertTrue(in_t.shape is not None) return in_t + 1 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 14b6fc4ac26f74f54628ae37ad6437c7d3e8caba..d3b23d949ee2c7674c3918d39e8b71d76eefcfec 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -132,6 +132,7 @@ py_library( srcs = ["estimator.py"], srcs_version = "PY2AND3", deps = [ + ":custom_loss_head", ":estimator_utils", ":model", "//tensorflow/contrib/boosted_trees:losses", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 
a3df272e6924792128fc38fd153b9527b58b486e..b314b4d74df882a421d9a2ecce2629a63d5c5248 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -41,7 +41,8 @@ def make_custom_export_strategy(name, convert_fn, feature_columns, export_input_fn, - use_core_columns=False): + use_core_columns=False, + feature_engineering_fn=None): """Makes custom exporter of GTFlow tree format. Args: @@ -52,6 +53,7 @@ def make_custom_export_strategy(name, export_input_fn: A function that takes no arguments and returns an `InputFnOps`. use_core_columns: A boolean, whether core feature columns were used. + feature_engineering_fn: Feature eng function to be called on the input. Returns: An `ExportStrategy`. @@ -59,9 +61,12 @@ def make_custom_export_strategy(name, base_strategy = saved_model_export_utils.make_export_strategy( serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() + features = input_fn.features + if feature_engineering_fn is not None: + features, _ = feature_engineering_fn(features, labels=None) (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( - input_fn.features, feature_columns, use_core_columns) + features, feature_columns, use_core_columns) def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): """A wrapper to export to SavedModel, and convert it to other formats.""" diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index ca73e4af2fbd0a383d02fa7111f59161701661df..358404cd946bbc56d2f7228be8fe4223749c850b 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -36,7 +36,7 @@ from 
tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn -from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 38d19976ef38a295a172e935f70bdae3c67f01e2..a178820841c4c8bcb7f5742babdb6d0f4825de31 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator @@ -26,7 +28,8 @@ from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import losses as core_losses - +from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head +from tensorflow.python.ops import array_ops # ================== Old estimator interface=================================== # The estimators below were designed for old feature columns and old estimator @@ -414,6 +417,108 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) +# When using 
this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. + """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None, + num_quantiles=100): + """Initializes a GradientBoostedDecisionTreeQuantileRegressor instance. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. 
+ center_bias: Whether a separate tree should be created for first fitting + the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + """ + + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _quantile_regression_head(quantile): + # Use quantile regression. 
+ head = custom_loss_head.CustomLossHead( + loss_fn=functools.partial( + losses.per_example_quantile_regression_loss, quantile=quantile), + link_fn=array_ops.identity, + logit_dimension=label_dimension) + return head + + learner_config.num_classes = max(2, label_dimension) + + super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=model.model_builder, + params={ + 'head': _quantile_regression_head(quantiles[0]), + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'logits_modifier_function': logits_modifier_function, + 'center_bias': center_bias, + 'use_core_libs': use_core_libs, + 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, + }, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. 
@@ -437,12 +542,42 @@ def core_multiclass_head( # pylint:disable=protected-access head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( - n_classes=n_classes, loss_fn=loss_fn, loss_reduction=loss_reduction) + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) # pylint:enable=protected-access return head_fn +# For quantile regression, use this head with Core..Estimator, or use +# Core..QuantileRegressor directly, +def core_quantile_regression_head( + quantiles, + label_dimension=1, + weight_column=None, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + """Core head for quantile regression problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_quantile_regression_loss( + labels=labels, + predictions=logits, + weights=weight_column, + quantile=quantiles) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._regression_head( + label_dimension=label_dimension, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) + # pylint:enable=protected-access + return head_fn + + class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): """An estimator using gradient boosted decision trees. @@ -606,3 +741,104 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): super(CoreGradientBoostedDecisionTreeRanker, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class CoreGradientBoostedDecisionTreeQuantileRegressor( + core_estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. 
+ """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False, + num_quantiles=100): + """Initializes a core gradient boosted decision tree quantile regressor. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'.
For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. + """ + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': + core_quantile_regression_head( + quantiles[0], label_dimension=label_dimension), + 'feature_columns': + feature_columns, + 'learner_config': + learner_config, + 'num_trees': + num_trees, + 'weight_column_name': + weight_column_name, + 'examples_per_layer': + examples_per_layer, + 'center_bias': + center_bias, + 'logits_modifier_function': + logits_modifier_function, + 'use_core_libs': + True, + 'output_leaf_index': + output_leaf_index, + 'override_global_step_value': + None, + 'num_quantiles': + num_quantiles, + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index c155128c0e4ccf928349ee6453baff4384222096..ee052ac60387d8f993e4942dd7dff39e191dd3a4 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.python.estimator.canned import head as head_lib +from 
tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -47,8 +48,8 @@ def _multiclass_train_input_fn(): features = { "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) } - label = constant_op.constant( - [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + label = constant_op.constant([[1], [0], [0], [2], [2], [0], [1]], + dtype=dtypes.int32) return features, label @@ -77,6 +78,59 @@ def _infer_ranking_train_input_fn(): return features, None +_QUANTILE_REGRESSION_SIZE = 1000 + + +def _quantile_regression_input_fns(two_dimension=False): + # The data generation is taken from + # http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html + np.random.seed(1) + + def f(x): + """The function to predict.""" + return x * np.sin(x) + + def g(x): + """The function to predict.""" + return x * np.cos(x) + + # Training data. + x = np.atleast_2d(np.random.uniform(0, 10.0, + size=_QUANTILE_REGRESSION_SIZE)).T + x = x.astype(np.float32) + + # Labels. + if not two_dimension: + y = f(x).ravel() + else: + y = np.column_stack((f(x).ravel(), g(x).ravel())) + + # Add random noise. + dy = 1.5 + 1.0 * np.random.random(y.shape) + noise = np.random.normal(0, dy) + y += noise + y_original = y.astype(np.float32) + if not two_dimension: + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + train_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=None, + shuffle=True) + + # Test on the training data to make sure the predictions are calibrated. 
+ test_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=1, + shuffle=False) + + return train_input_fn, test_input_fn, y_original + + class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): @@ -341,6 +395,130 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): for prediction_dict in result_iter: self.assertTrue("classes" in prediction_dict) + # One dimensional quantile regression. + def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + + # 95% percentile. + model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + frac_below_upper = round(1. 
* np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + + # 95% percentile. 
+ model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -489,8 +667,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): feature_columns = [ core_feature_column.weighted_categorical_column( - categorical_column=core_feature_column. - categorical_column_with_vocabulary_list( + categorical_column=core_feature_column + .categorical_column_with_vocabulary_list( key="word", vocabulary_list=["the", "cat", "dog"]), weight_feature_key="weight") ] @@ -509,8 +687,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): # Weights for the words are 5 - cat, 6- dog and 1 -the. features_dict["word"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], - values=constant_op.constant( - ["the", "cat", "dog", "the"], dtype=dtypes.string), + values=constant_op.constant(["the", "cat", "dog", "the"], + dtype=dtypes.string), dense_shape=[4, 3]) features_dict["weight"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], @@ -534,6 +712,132 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) + # One dimensional quantile regression. 
+ def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + frac_above_lower = round(1. 
* np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. 
* count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f45010ec26ed25127ca78b97f4d6fd7ebd6467ae..1fffbb5f660c681e1dde11a2aaf1d0f1cf79d1d0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -142,7 +142,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): name="StatsAccumulator/{}".format(self._name)) # Allocate both stats accumulator and quantile accumulator on the same # device so that we can build splits with fewer RPCs. 
- with ops.colocate_with(self._stats_accumulator.resource()): + with ops.colocate_with(self._stats_accumulator.resource_handle): self._quantile_accumulator = quantile_ops.QuantileAccumulator( init_stamp_token, epsilon=epsilon, @@ -268,8 +268,8 @@ class DenseSplitHandler(InequalitySplitHandler): handler = make_dense_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, @@ -447,8 +447,8 @@ class SparseSplitHandler(InequalitySplitHandler): handler = make_sparse_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 05ce0884ccfff53484fdc0c26e596e7fb6fcdfd6..356ae337685d580319da16a20bbab27ccaa73255 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -34,7 +34,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with 
ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -62,7 +62,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2, 1], @@ -91,7 +91,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -123,7 +123,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -133,7 +133,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. 
with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -164,7 +164,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. op1 = accumulator.add( stamp_token=0, @@ -175,7 +175,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): deserialize = ( - accumulator.deserialize( + accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], @@ -223,7 +223,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -261,7 +261,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -299,7 +299,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -336,7 +336,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, 
gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -349,7 +349,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -386,7 +386,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. 
op1 = accumulator.add( stamp_token=0, @@ -399,7 +399,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): 0.08]]]) with ops.control_dependencies([op1]): - deserialize = accumulator.deserialize( + deserialize = accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 25b2c9e2fd72bd018717e8a87fce726f26bad968..fca22c71a83459cb290eaebcf107cf1c14c222b7 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader # pylint: enable=unused-import @@ -31,6 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -82,6 +85,44 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): tree_ensemble_config=restored_tensors[1]) +class TreeEnsembleVariable(tracking.TrackableResource): + """A Tree ensemble model.""" + + def __init__(self, stamp_token, tree_ensemble_config, name, container=None): + self._stamp_token = stamp_token + self._tree_ensemble_config = tree_ensemble_config + self._name = name + self._container = container + self._init_op = None + super(TreeEnsembleVariable, self).__init__() + + def create_resource(self): + return gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._container, shared_name=self._name, 
name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_ensemble_variable( + self.resource_handle, self._stamp_token, self._tree_ensemble_config) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_ensemble_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return { + "tree_ensemble_variable": + functools.partial( + TreeEnsembleVariableSavable, + tree_ensemble_handle=self.resource_handle, + create_op=self.initializer) + } + + def tree_ensemble_variable(stamp_token, tree_ensemble_config, name, @@ -99,12 +140,11 @@ def tree_ensemble_variable(stamp_token, A `Tensor` of type mutable `string`. The handle to the tree ensemble. """ with ops.name_scope(name, "TreeEnsembleVariable") as name: - resource_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op( - container, shared_name=name, name=name) - create_op = gen_model_ops.create_tree_ensemble_variable( - resource_handle, stamp_token, tree_ensemble_config) - is_initialized_op = gen_model_ops.tree_ensemble_is_initialized_op( - resource_handle) + tree_ensemble_var = TreeEnsembleVariable(stamp_token, tree_ensemble_config, + name, container) + resource_handle = tree_ensemble_var.resource_handle + create_op = tree_ensemble_var.initializer + is_initialized_op = tree_ensemble_var.is_initialized() # Adds the variable to the savable list. 
saveable = TreeEnsembleVariableSavable(resource_handle, create_op, resource_handle.name) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 19b6b3296db394b07f57a25dbde187eb9195af38..0c319cc9bd1f720eb404a9da05227c5807ec874f 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,59 +33,20 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") -class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): - """A resource that allows distributed quantile computation.""" - - def __init__(self, - init_stamp_token, - epsilon, - num_quantiles, - max_elements=None, - name=None, - container=None, - generate_quantiles=False): - """Creates a QuantileAccumulator object. - - Args: - init_stamp_token: The initial value for the stamp token. - epsilon: Error bound on the quantile computation. - num_quantiles: Number of quantiles to produce from the final summary. - max_elements: Maximum number of elements added to the accumulator. - name: the name to save the accumulator under. - container: An optional `string`. Defaults to `""` - generate_quantiles: Generate quantiles instead of approximate boundaries. - If true, exactly `num_quantiles` will be produced in the final summary. 
- """ - self._epsilon = epsilon - self._generate_quantiles = generate_quantiles +class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for QuantileAccumulator.""" - name = _PATTERN.sub("", name) - with ops.name_scope(name, "QuantileAccumulator") as name: - self._quantile_accumulator_handle = ( - gen_quantile_ops.quantile_stream_resource_handle_op( - container=container, shared_name=name, name=name)) - self._create_op = gen_quantile_ops.create_quantile_accumulator( - self._quantile_accumulator_handle, - init_stamp_token, - epsilon=epsilon, - max_elements=max_elements, - num_quantiles=num_quantiles, - generate_quantiles=generate_quantiles) - is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( - self._quantile_accumulator_handle) - resources.register_resource(self._quantile_accumulator_handle, - self._create_op, is_initialized_op) - self._make_savable(name) - - def _make_savable(self, name): + def __init__(self, resource_handle, create_op, name): + self._resource_handle = resource_handle + self._create_op = create_op stamp_token, state, are_buckets_ready, buckets = ( - gen_quantile_ops.quantile_accumulator_serialize( - self._quantile_accumulator_handle)) + gen_quantile_ops.quantile_accumulator_serialize(resource_handle)) # slice_spec is useful for saving a slice from a variable. # It's not meaningful in quantile accumulator. 
slice_spec = "" @@ -96,9 +57,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): specs += [make_save_spec(state, "_state")] specs += [make_save_spec(are_buckets_ready, "_are_buckets_ready")] specs += [make_save_spec(buckets, "buckets")] - super(QuantileAccumulator, - self).__init__(self._quantile_accumulator_handle, specs, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle, + specs, name) def restore(self, restored_tensors, unused_restored_shapes): """Restores the associated quantile accumulator from 'restored_tensors'. @@ -119,24 +79,94 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): buckets = restored_tensors[3] with ops.control_dependencies([self._create_op]): return gen_quantile_ops.quantile_accumulator_deserialize( - self._quantile_accumulator_handle, + self._resource_handle, stamp_token=stamp_token, stream_state=state, are_buckets_ready=are_buckets_ready, buckets=buckets) + +class QuantileAccumulator(tracking.TrackableResource): + """A resource that allows distributed quantile computation.""" + + def __init__(self, + init_stamp_token, + epsilon, + num_quantiles, + max_elements=None, + name=None, + container=None, + generate_quantiles=False): + """Creates a QuantileAccumulator object. + + Args: + init_stamp_token: The initial value for the stamp token. + epsilon: Error bound on the quantile computation. + num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. + name: the name to save the accumulator under. + container: An optional `string`. Defaults to `""` + generate_quantiles: Generate quantiles instead of approximate boundaries. + If true, exactly `num_quantiles` will be produced in the final summary. 
+ """ + self._init_stamp_token = init_stamp_token + self._epsilon = epsilon + self._num_quantiles = num_quantiles + self._max_elements = max_elements + self._container = container + self._generate_quantiles = generate_quantiles + super(QuantileAccumulator, self).__init__() + + name = _PATTERN.sub("", name) + with ops.name_scope(name, "QuantileAccumulator") as name: + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self._init_op, + is_initialized_op) + self._saveable = QuantileAccumulatorSaveable(self.resource_handle, + self._init_op, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + def create_resource(self): + return gen_quantile_ops.quantile_stream_resource_handle_op( + container=self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_quantile_ops.create_quantile_accumulator( + self.resource_handle, + self._init_stamp_token, + epsilon=self._epsilon, + max_elements=self._max_elements, + num_quantiles=self._num_quantiles, + generate_quantiles=self._generate_quantiles) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_quantile_ops.quantile_accumulator_is_initialized( + self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return {"quantile_accumulator", self.saveable} + def get_buckets(self, stamp_token): """Returns quantile buckets created during previous flush.""" are_buckets_ready, buckets = ( gen_quantile_ops.quantile_accumulator_get_buckets( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token)) return are_buckets_ready[0], buckets[0] def schedule_get_buckets(self): """Returns a scheduled read of buckets created 
during previous flush.""" return batch_ops_utils.ScheduledStampedResourceOp( - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, op=gen_quantile_ops.quantile_accumulator_get_buckets) def _make_summary(self, column, example_weights): @@ -161,14 +191,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Adds quantile summary to its stream in resource.""" summary = self._make_summary(column, example_weights) return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) def add_prebuilt_summary(self, stamp_token, summary): """Adds quantile summary to its stream in resource.""" return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) @@ -177,7 +207,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): summary = self._make_summary(column, example_weights) return batch_ops_utils.ScheduledStampedResourceOp( op=gen_quantile_ops.quantile_accumulator_add_summaries, - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, summaries=summary) def flush(self, stamp_token, next_stamp_token): @@ -190,17 +220,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): The flush operation. 
""" return gen_quantile_ops.quantile_accumulator_flush( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) def flush_summary(self, stamp_token, next_stamp_token): """Finalizes quantile summary stream and resets it for next iteration.""" result = gen_quantile_ops.quantile_accumulator_flush_summary( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result - - def resource(self): - return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index 2e94e353f325f06eed2d290d3a7a461861820c39..ad1191d41236e71008bff8c8a7fbd42c16e3f9c5 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,12 +26,83 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") -class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): +class StatsAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for StatsAccumulator.""" + + def __init__(self, resource_handle, create_op, is_scalar, name): + self._create_op = create_op + self._resource_handle = resource_handle + self._is_scalar = is_scalar + slice_spec = "" + saver_name = self._resource_handle.name + (stamp_token, num_updates, partition_ids, feature_ids, gradients, + hessians) = self.serialize() + specs = [ + saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, + saver_name + "_stamp"), + saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, + saver_name + "_num_updates"), + saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, + saver_name + "_partition_ids"), + saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, + saver_name + "_feature_ids"), + saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, + saver_name + "_gradients"), + saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, + saver_name + "hessians"), + ] + super(StatsAccumulatorSaveable, self).__init__(self._resource_handle, specs, + name) + + def serialize(self): + """Serializes the stats accumulator state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( + self._resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( + self._resource_handle) + + def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, + gradients, hessians): + """Resets the stats accumulator with the serialized state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( + self._resource_handle, stamp_token, num_updates, 
partition_ids, + feature_ids, gradients, hessians) + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree ensemble from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree ensemble variable. + """ + with ops.control_dependencies([self._create_op]): + return self.deserialize( + stamp_token=restored_tensors[0], + num_updates=restored_tensors[1], + partition_ids=restored_tensors[2], + feature_ids=restored_tensors[3], + gradients=restored_tensors[4], + hessians=restored_tensors[5]) + + +class StatsAccumulator(tracking.TrackableResource): """A resource that allows to accumulate gradients and hessians. For consistency guarantees, we use read and write stamp tokens. @@ -58,58 +129,69 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): Returns: A `Tensor` of type mutable `string`. The handle to the stats accumulator. """ + self._stamp_token = stamp_token + self._gradient_shape = gradient_shape + self._hessian_shape = hessian_shape + self._container = container + + if (gradient_shape == tensor_shape.scalar() and + hessian_shape == tensor_shape.scalar()): + self._is_scalar = True + else: + self._is_scalar = False + if name is not None: name = _PATTERN.sub("", name) with ops.name_scope(name, "StatsAccumulator") as name: - # Both values are scalars. - if (gradient_shape == tensor_shape.scalar() and - hessian_shape == tensor_shape.scalar()): - self._is_scalar = True - self._resource_handle = (gen_stats_accumulator_ops. 
- stats_accumulator_scalar_resource_handle_op( - container, name, name=name)) - - create_op = gen_stats_accumulator_ops.create_stats_accumulator_scalar( - self._resource_handle, stamp_token) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( - self._resource_handle)) - else: - self._is_scalar = False - self._resource_handle = (gen_stats_accumulator_ops. - stats_accumulator_tensor_resource_handle_op( - container, name, name=name)) - create_op = gen_stats_accumulator_ops.create_stats_accumulator_tensor( - self._resource_handle, stamp_token, gradient_shape.as_list(), - hessian_shape.as_list()) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( - self._resource_handle)) + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self.initializer, + is_initialized_op) + self._saveable = StatsAccumulatorSaveable( + self.resource_handle, self.initializer, self._is_scalar, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - self._create_op = create_op - slice_spec = "" - saver_name = self._resource_handle.name - (stamp_token, num_updates, partition_ids, feature_ids, gradients, - hessians) = self.serialize() - specs = [ - saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, - saver_name + "_stamp"), - saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, - saver_name + "_num_updates"), - saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, - saver_name + "_partition_ids"), - saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, - saver_name + "_feature_ids"), - saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, - saver_name + "_gradients"), - saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, - saver_name + "hessians"), - ] + def create_resource(self): + if self._is_scalar: + return ( + 
gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op( + self._container, self._name, name=self._name)) + else: + return ( + gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op( + self._container, self._name, name=self._name)) - super(StatsAccumulator, self).__init__(self._resource_handle, specs, name) - resources.register_resource(self._resource_handle, create_op, - is_initialized_op) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + def initialize(self): + if self._is_scalar: + return gen_stats_accumulator_ops.create_stats_accumulator_scalar( + self.resource_handle, self._stamp_token) + else: + return gen_stats_accumulator_ops.create_stats_accumulator_tensor( + self.resource_handle, self._stamp_token, + self._gradient_shape.as_list(), self._hessian_shape.as_list()) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( + self.resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( + self.resource_handle) + + @property + def saveable(self): + return self._saveable + + def _gather_saveables_for_checkpoint(self): + return {"stats_accumulator", self.saveable} def add(self, stamp_token, partition_ids, feature_ids, gradients, hessians): """Updates the stats accumulator.""" @@ -117,11 +199,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): partition_ids, feature_ids, gradients, hessians)) if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_add( - [self._resource_handle], stamp_token, [partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_add( - [self._resource_handle], stamp_token, [partition_ids], 
[feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) def schedule_add(self, partition_ids, feature_ids, gradients, hessians): @@ -131,7 +213,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): if self._is_scalar: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_scalar_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -139,7 +221,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): else: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_tensor_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -153,55 +235,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): return gen_stats_accumulator_ops.stats_accumulator_tensor_make_summary( partition_ids, feature_ids, gradients, hessians) - def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, - gradients, hessians): - """Resets the stats accumulator with the serialized state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - def flush(self, stamp_token, next_stamp_token): """Flushes the stats accumulator.""" if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_flush( - self._resource_handle, stamp_token, next_stamp_token) + self.resource_handle, stamp_token, next_stamp_token) else: return 
gen_stats_accumulator_ops.stats_accumulator_tensor_flush( - self._resource_handle, stamp_token, next_stamp_token) - - def serialize(self): - """Serializes the stats accumulator state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( - self._resource_handle) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( - self._resource_handle) - - def restore(self, restored_tensors, unused_restored_shapes): - """Restores the associated tree ensemble from 'restored_tensors'. - - Args: - restored_tensors: the tensors that were loaded from a checkpoint. - unused_restored_shapes: the shapes this object should conform to after - restore. Not meaningful for trees. - - Returns: - The operation that restores the state of the tree ensemble variable. - """ - with ops.control_dependencies([self._create_op]): - return self.deserialize( - stamp_token=restored_tensors[0], - num_updates=restored_tensors[1], - partition_ids=restored_tensors[2], - feature_ids=restored_tensors[3], - gradients=restored_tensors[4], - hessians=restored_tensors[5]) - - def resource(self): - return self._resource_handle + self.resource_handle, stamp_token, next_stamp_token) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 1cf61a10ba25f206333bb78b7944e366bcd19b92..85020c5df293598e79de0e964f55af5231aa3622 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -992,7 +992,7 @@ class GradientBoostedDecisionTreeModel(object): # Get accumulated steps and examples for the current layer. 
_, _, _, _, acc_examples, acc_steps = ( - steps_accumulator.serialize()) + steps_accumulator.saveable.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) ensemble_update_ops.append( @@ -1257,13 +1257,12 @@ class GradientBoostedDecisionTreeModel(object): def _get_replica_device_setter(self, worker_device): """Creates a replica device setter.""" ps_tasks = self._num_ps_replicas - ps_ops = [ - "Variable", - "VariableV2", + ps_ops = list(device_setter.STANDARD_PS_OPS) + ps_ops.extend([ "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - ] + ]) ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( worker_device=worker_device, diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index b5ebaf1999519f65110e8164fa20bace5ecc3ef6..7a99dccdd1066354ee50dad0622a6fbda9c860ff 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -48,6 +48,47 @@ def per_example_logistic_loss(labels, weights, predictions): labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() +# MUST USE WITH HESSIAN REGULARIZATION, +# This loss can have zero hessian, so it must be used with l2 or min_node_weight +# regularization. +# An example config is +# learner_config.constraints.min_node_weight = 1 / num_examples_per_layer +# learner_config.regularization.l2 = 1.0 / num_examples_per_layer +# TODO(nponomareva): make it multidimensional so we can estimate several +# quantiles at once. +def per_example_quantile_regression_loss(labels, weights, predictions, + quantile): + """Smoothed loss for quantile regression. 
+ + The standard quantile regression loss is quantile*(y-y') when y>y' and + (quantile-1)*(y-y') otherwise, y' is a prediction, y is a label. The impl + below is this loss but squared in the region where the loss value < 1. + + Args: + labels: Rank 2 (N, D) tensor of per-example labels. + weights: Rank 2 (N, 1) tensor of per-example weights. + predictions: Rank 2 (N, D) tensor of per-example predictions. + quantile: The quantile to use. + + Returns: + loss: A Rank 2 (N, 1) tensor of per-example quantile loss. + update_op: An update operation to update the loss's internal state. + """ + labels = math_ops.to_float(labels) + error = labels - predictions + square_loss_right = array_ops.where(error * quantile < 1.0, + math_ops.square(quantile * error), + quantile * error) + square_loss_left = array_ops.where(error * (quantile - 1) < 1, + math_ops.square((quantile - 1) * error), + (quantile - 1) * error) + + unweighted_loss = array_ops.where(error > 0, square_loss_right, + square_loss_left) + if weights is None: + return unweighted_loss, control_flow_ops.no_op() + else: + return unweighted_loss * weights, control_flow_ops.no_op() # This is classical form of Maximum entropy loss, that is twice differentiable # (sparse_softmax_cross_entropy which is what we go for is not twice diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 5ecd4f341831ce8d6f8eb04a763280c177ffe275..7774ac0e122a532e1e0280f185ead3022a0b89d6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -25,6 +25,13 @@ import six from tensorflow.python.training.server_lib import ClusterSpec +def format_master_url(master, rpc_layer=None): + if rpc_layer: + return '%s://%s' % (rpc_layer, master) + else: + return master + + @six.add_metaclass(abc.ABCMeta) class ClusterResolver(object): 
"""Abstract class for all implementations of ClusterResolvers. @@ -37,6 +44,17 @@ class ClusterResolver(object): automatically discover and resolve IP addresses for various TensorFlow workers. This will eventually allow us to automatically recover from underlying machine failures and scale TensorFlow worker clusters up and down. + + Note to Implementors: In addition to these abstract methods, you must also + implement the task_type, task_index, and rpc_layer attributes. You may choose + to implement them either as properties with getters or setters or directly + set the attributes. + + - task_type is the name of the server's current named job (e.g. 'worker', + 'ps' in a distributed parameterized training job). + - task_index is the ordinal index of the server within the task type. + - rpc_layer is the protocol used by TensorFlow to communicate with other + TensorFlow servers in a distributed environment. """ @abc.abstractmethod @@ -53,16 +71,16 @@ class ClusterResolver(object): management system every time this function is invoked and reconstructing a cluster_spec, rather than attempting to cache anything. """ - raise NotImplementedError( - 'cluster_spec is not implemented for {}.'.format(self)) + raise NotImplementedError() @abc.abstractmethod - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Retrieves the name or URL of the session master. Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. Returns: The name or URL of the session master. @@ -71,16 +89,44 @@ class ClusterResolver(object): returned is up-to-date at the time to calling this function. This usually means retrieving the master every time this function is invoked. 
""" - raise NotImplementedError('master is not implemented for {}.'.format(self)) + raise NotImplementedError() + + @abc.abstractmethod + def num_accelerators_per_worker(self, session_config=None): + """Returns the number of accelerator cores per worker. + + This returns the number of accelerator cores (such as GPUs and TPUs) + available per worker. If workers only has CPU cores available, then this + should return 0. This method will query the master for this information + if it is not otherwise known. + + Args: + session_config: (Optional) Configuration for starting a new session to + query how many accelerator cores it has. + """ + raise NotImplementedError() + + @abc.abstractproperty + def environment(self): + """Returns the current environment which TensorFlow is running in.""" + raise NotImplementedError() class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec, master=''): + def __init__(self, cluster_spec, master='', task_type=None, task_index=None, + environment='', num_accelerators_per_worker=0, + rpc_layer=None): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() + self._task_type = task_type + self._task_index = task_index + self._environment = environment + self._num_accelerators_per_worker = num_accelerators_per_worker + self._rpc_layer = rpc_layer + if not isinstance(cluster_spec, ClusterSpec): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec @@ -93,12 +139,13 @@ class SimpleClusterResolver(ClusterResolver): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a session. Args: task_type: (Optional) The type of the TensorFlow task of the master. 
task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC used by distributed TensorFlow. Returns: The name or URL of the session master. @@ -106,10 +153,52 @@ class SimpleClusterResolver(ClusterResolver): If a task_type and task_index is given, this will override the `master` string passed into the initialization function. """ - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + else: + master = self._master + + return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer) + + @property + def task_type(self): + return self._task_type + + @property + def task_index(self): + return self._task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._environment + + def num_accelerators_per_worker(self, session_config=None): + """Returns the number of accelerator cores per worker. + + Args: + session_config: Unused. The SimpleClusterResolver does not do automatic + detection of accelerators, so a TensorFlow session will never be + created, and thus a `session_config` is never necessary here, and will + be ignored. + """ + del session_config + return self._num_accelerators_per_worker - return self._master + @property + def rpc_layer(self): + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer class UnionClusterResolver(ClusterResolver): @@ -119,13 +208,22 @@ class UnionClusterResolver(ClusterResolver): merges the underlying ClusterResolvers, and returns one unified ClusterSpec when cluster_spec is called. The details of the merge function is documented in the cluster_spec function. 
+ + For additional Cluster Resolver properties such as task type, task index, + rpc layer, environment, etc..., we will return the value from the first + ClusterResolver in the union. """ - def __init__(self, *args): + def __init__(self, *args, **kwargs): """Initializes a UnionClusterResolver with other ClusterResolvers. Args: *args: `ClusterResolver` objects to be unionized. + **kwargs: + rpc_layer - (Optional) Override value for the RPC layer used by + TensorFlow. + task_type - (Optional) Override value for the current task type. + task_index - (Optional) Override value for the current task index. Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. @@ -133,6 +231,13 @@ class UnionClusterResolver(ClusterResolver): """ super(UnionClusterResolver, self).__init__() + self._rpc_layer = kwargs.pop('rpc_layer', None) + self._task_type = kwargs.pop('task_type', None) + self._task_index = kwargs.pop('task_index', None) + + if kwargs: + raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs)) + if not args: raise ValueError('At least one ClusterResolver is required.') @@ -216,7 +321,7 @@ class UnionClusterResolver(ClusterResolver): return ClusterSpec(merged_cluster) - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a session. This usually returns the master from the first ClusterResolver passed in, @@ -225,11 +330,45 @@ class UnionClusterResolver(ClusterResolver): Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. Returns: The name or URL of the session master. 
""" - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + return format_master_url(master, rpc_layer or self._rpc_layer) + + return self._cluster_resolvers[0].master(rpc_layer=rpc_layer) + + @property + def task_type(self): + return self._task_type or self._cluster_resolvers[0].task_type + + @property + def task_index(self): + return self._task_index or self._cluster_resolvers[0].task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._cluster_resolvers[0].environment + + def num_accelerators_per_worker(self, session_config=None): + return self._cluster_resolvers[0].num_accelerators_per_worker( + session_config) + + @property + def rpc_layer(self): + return self._rpc_layer or self._cluster_resolvers[0].rpc_layer - return self._cluster_resolvers[0].master() + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index c004b2e2d3bc6552a3ab10997ed44f24e611735a..b94c9612b5bd4d92e84319f22932ce5599ba4b36 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -57,6 +57,62 @@ class UnionClusterResolverTest(test.TestCase): actual_cluster_spec = union_resolver.cluster_spec() self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testInitSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", 
"worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + self.assertEqual(simple_resolver.task_type, "ps") + self.assertEqual(simple_resolver.task_index, 1) + self.assertEqual(simple_resolver.environment, "cloud") + self.assertEqual(simple_resolver.num_accelerators_per_worker(), 8) + self.assertEqual(simple_resolver.rpc_layer, "grpc") + + def testOverrideSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + simple_resolver.task_type = "worker" + simple_resolver.task_index = 2 + simple_resolver.rpc_layer = "http" + + self.assertEqual(simple_resolver.task_type, "worker") + self.assertEqual(simple_resolver.task_index, 2) + self.assertEqual(simple_resolver.rpc_layer, "http") + + def testSimpleOverrideMasterWithTaskIndexZero(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec) + actual_master = simple_resolver.master("worker", 0, rpc_layer="grpc") + self.assertEqual(actual_master, "grpc://worker0:2222") + + def testSimpleOverrideMasterWithRpcLayer(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec) + actual_master = simple_resolver.master("worker", 2, rpc_layer="grpc") + self.assertEqual(actual_master, "grpc://worker2:2222") + def testSimpleOverrideMaster(self): base_cluster_spec = server_lib.ClusterSpec({ "ps": 
["ps0:2222", "ps1:2222"], @@ -65,7 +121,42 @@ class UnionClusterResolverTest(test.TestCase): simple_resolver = SimpleClusterResolver(base_cluster_spec) actual_master = simple_resolver.master("worker", 2) - self.assertEquals(actual_master, "worker2:2222") + self.assertEqual(actual_master, "worker2:2222") + + def testUnionClusterResolverGetProperties(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + resolver1 = SimpleClusterResolver(cluster_spec_1, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + cluster_spec_2 = server_lib.ClusterSpec({ + "ps": ["ps2:2222", "ps3:2222"], + "worker": ["worker3:2222", "worker4:2222", "worker5:2222"] + }) + resolver2 = SimpleClusterResolver(cluster_spec_2, task_type="worker", + task_index=2, environment="local", + num_accelerators_per_worker=16, + rpc_layer="http") + + union_resolver = UnionClusterResolver(resolver1, resolver2) + + self.assertEqual(union_resolver.task_type, "ps") + self.assertEqual(union_resolver.task_index, 1) + self.assertEqual(union_resolver.environment, "cloud") + self.assertEqual(union_resolver.num_accelerators_per_worker(), 8) + self.assertEqual(union_resolver.rpc_layer, "grpc") + + union_resolver.task_type = "worker" + union_resolver.task_index = 2 + union_resolver.rpc_layer = "http" + + self.assertEqual(union_resolver.task_type, "worker") + self.assertEqual(union_resolver.task_index, 2) + self.assertEqual(union_resolver.rpc_layer, "http") def testTwoNonOverlappingJobMergedClusterResolver(self): cluster_spec_1 = server_lib.ClusterSpec({ @@ -116,10 +207,13 @@ class UnionClusterResolverTest(test.TestCase): union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) unspecified_master = union_cluster.master() - self.assertEquals(unspecified_master, "") + self.assertEqual(unspecified_master, "") specified_master = 
union_cluster.master("worker", 1) - self.assertEquals(specified_master, "worker1:2222") + self.assertEqual(specified_master, "worker1:2222") + + rpc_master = union_cluster.master("worker", 1, rpc_layer="grpc") + self.assertEqual(rpc_master, "grpc://worker1:2222") def testOverlappingJobMergedClusterResolver(self): cluster_spec_1 = server_lib.ClusterSpec({ diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 5083e4d10ba6ee2e1be8f373c099556b422ef5aa..195b68959b6d21ef674438a4a23a4dd07f45faa7 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -30,6 +30,10 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +def _format_master_url(master, rpc_layer=None): + return '%s://%s' % (rpc_layer, master) if rpc_layer else master + + class GceClusterResolver(ClusterResolver): """Cluster Resolver for Google Compute Engine. @@ -45,7 +49,10 @@ class GceClusterResolver(ClusterResolver): zone, instance_group, port, - job_name='worker', + task_type='worker', + task_index=0, + rpc_layer='grpc', + num_accelerators_per_worker=0, credentials='default', service=None): """Creates a new GceClusterResolver object. @@ -55,13 +62,22 @@ class GceClusterResolver(ClusterResolver): each instance in the instance group. Args: - project: Name of the GCE project - zone: Zone of the GCE instance group - instance_group: Name of the GCE instance group + project: Name of the GCE project. + zone: Zone of the GCE instance group. + instance_group: Name of the GCE instance group. port: Port of the listening TensorFlow server (default: 8470) - job_name: Name of the TensorFlow job this set of instances belongs to + task_type: Name of the TensorFlow job this GCE instance group of VM + instances belong to. 
+ task_index: The task index for this particular VM, within the GCE + instance group. In particular, every single instance should be assigned + a unique ordinal index within an instance group manually so that they + can be distinguished from each other. + rpc_layer: The RPC layer TensorFlow should use to communicate across + instances. + num_accelerators_per_worker: Number of accelerators (GPUs) present per + instance. credentials: GCE Credentials. If nothing is specified, this defaults to - GoogleCredentials.get_application_default() + GoogleCredentials.get_application_default(). service: The GCE API object returned by the googleapiclient.discovery function. (Default: discovery.build('compute', 'v1')). If you specify a custom service object, then the credentials parameter will be ignored. @@ -72,7 +88,9 @@ class GceClusterResolver(ClusterResolver): self._project = project self._zone = zone self._instance_group = instance_group - self._job_name = job_name + self._task_type = task_type + self._task_index = task_index + self._rpc_layer = rpc_layer self._port = port self._credentials = credentials @@ -133,10 +151,58 @@ class GceClusterResolver(ClusterResolver): previous_response=response) worker_list.sort() - return ClusterSpec({self._job_name: worker_list}) + return ClusterSpec({self._task_type: worker_list}) - def master(self, task_type=None, task_index=None): - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + def master(self, task_type=None, task_index=None, rpc_layer=None): + task_type = task_type if task_type is not None else self._task_type + task_index = task_index if task_index is not None else self._task_index + + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + if rpc_layer or self._rpc_layer: + return '%s://%s' % (rpc_layer or self._rpc_layer, master) + else: + return master return '' + + @property + def task_type(self): + return 
self._task_type + + @property + def task_index(self): + return self._task_index + + @task_type.setter + def task_type(self, task_type): + raise RuntimeError( + 'You cannot reset the task_type of the GceClusterResolver after it has ' + 'been created.') + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + """Returns the current environment which TensorFlow is running in. + + For users in the GCE environment, the environment property is always an + empty string, and Google users will not use this ClusterResolver for running + on internal systems. + """ + return '' + + @property + def rpc_layer(self): + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer + + def num_accelerators_per_worker(self, session_config=None): + del session_config # Unused, since this is set manually in __init__. + return self._num_accelerators_per_worker + diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py index 87b8303122498992dd24ae06824f7f769357d8f8..c691552e86025896e23891a3e8f7da5ed2f9da31 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py @@ -135,12 +135,86 @@ class GceClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testMasterRetrieval(self): + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_index=0, + port=8470, + credentials=None, + service=self.standard_mock_service_client()) + self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470') + + def testMasterRetrievalWithCustomTasks(self): + name_to_ip = [ + {'name': 'instance1', 'ip': 
'10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + self.assertEqual( + gce_cluster_resolver.master('worker', 2, 'test'), + 'test://10.3.4.5:8470') + + def testOverrideParameters(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_type='testworker', + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + gce_cluster_resolver.task_index = 1 + gce_cluster_resolver.rpc_layer = 'test' + + self.assertEqual(gce_cluster_resolver.task_type, 'testworker') + self.assertEqual(gce_cluster_resolver.task_index, 1) + self.assertEqual(gce_cluster_resolver.rpc_layer, 'test') + self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470') + + def testOverrideParametersWithZeroOrEmpty(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_type='', + task_index=1, + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + self.assertEqual(gce_cluster_resolver.master( + task_type='', task_index=0), 'grpc://10.1.2.3:8470') + def testCustomJobNameAndPortRetrieval(self): gce_cluster_resolver = GceClusterResolver( project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='custom', + task_type='custom', 
port=2222, credentials=None, service=self.standard_mock_service_client()) @@ -196,7 +270,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='worker', + task_type='worker', port=8470, credentials=None, service=self.gen_standard_mock_service_client(worker1_name_to_ip)) @@ -205,7 +279,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='worker', + task_type='worker', port=8470, credentials=None, service=self.gen_standard_mock_service_client(worker2_name_to_ip)) @@ -214,7 +288,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='ps', + task_type='ps', port=2222, credentials=None, service=self.gen_standard_mock_service_client(ps_name_to_ip)) diff --git a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py index ddae64839f01b4f67fe4c0c0bc00199bb2e037aa..eab1359a5bdf0e15d630e209964fa46dce9b2d42 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import format_master_url +from tensorflow.python.client import device_lib from tensorflow.python.training import server_lib _KUBERNETES_API_CLIENT_INSTALLED = True @@ -41,6 +43,7 @@ class KubernetesClusterResolver(ClusterResolver): def __init__(self, job_to_label_mapping=None, tf_server_port=8470, + rpc_layer='grpc', override_client=None): """Initializes a new 
KubernetesClusterResolver. @@ -58,6 +61,8 @@ class KubernetesClusterResolver(ClusterResolver): 'ps': ['job-name=ps-1', 'job-name=ps-2']} ``` tf_server_port: The port the TensorFlow server is listening on. + rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate + between tasks in Kubernetes. Defaults to 'grpc'. override_client: The Kubernetes client (usually automatically retrieved using `from kubernetes import client as k8sclient`). If you pass this in, you are responsible for setting Kubernetes credentials manually. @@ -65,6 +70,7 @@ class KubernetesClusterResolver(ClusterResolver): Raises: ImportError: If the Kubernetes Python client is not installed and no `override_client` is passed in. + RuntimeError: If autoresolve_task is not a boolean or a callable. """ if _KUBERNETES_API_CLIENT_INSTALLED: k8sconfig.load_kube_config() @@ -82,16 +88,37 @@ class KubernetesClusterResolver(ClusterResolver): self._tf_server_port = tf_server_port self._override_client = override_client - def master(self): - # TODO(frankchn): Figure out a standard way to pass in the current task type - # and task id via Kubernetes. - pass + self.task_type = None + self.task_index = None + self.rpc_layer = rpc_layer - def get_master(self): - return self.master() + def master(self, task_type=None, task_index=None, rpc_layer=None): + """Returns the master address to use when creating a session. - def get_job_name(self): - return self._job_name + You must have set the task_type and task_index object properties before + calling this function, or pass in the `task_type` and `task_index` + parameters when using this function. If you do both, the function parameters + will override the object properties. + + Args: + task_type: (Optional) The type of the TensorFlow task of the master. + task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. + + Returns: + The name or URL of the session master. 
+ """ + if task_type is not None and task_index is not None: + return format_master_url( + self.cluster_spec().task_address(task_type, task_index), + rpc_layer or self.rpc_layer) + + if self.task_type is not None and self.task_index is not None: + return format_master_url( + self.cluster_spec().task_address(self.task_type, self.task_index), + rpc_layer or self.rpc_layer) + + return '' def cluster_spec(self): """Returns a ClusterSpec object based on the latest info from Kubernetes. @@ -130,3 +157,17 @@ class KubernetesClusterResolver(ClusterResolver): cluster_map[tf_job] = all_pods return server_lib.ClusterSpec(cluster_map) + + @property + def environment(self): + """Returns the current environment which TensorFlow is running in. + + For users in the Cloud environment, the environment property is always an + empty string, and Google users will not use this ClusterResolver for running + on internal systems. + """ + return '' + + def num_accelerators_per_worker(self, session_config=None): + local_devices = device_lib.list_local_devices(session_config) + return len([d for d in local_devices if d.device_type == 'GPU']) diff --git a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver_test.py index fbb26e803d73c96decf57a040a05694a434500f2..c63a98af6c24efa22c49c9ba38abd243c17d478e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver_test.py @@ -109,6 +109,23 @@ class KubernetesClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + def testGetMasterWithOverrideParameters(self): + ret = _create_pod_list( + ('worker-0', 'Running', '10.1.2.3'), + ('worker-1', 'Running', '10.1.2.4'), + ('worker-2', 'Running', '10.1.2.5')) + + cluster_resolver = KubernetesClusterResolver( 
+ override_client=_mock_kubernetes_client( + {'job-name=tensorflow': ret})) + cluster_resolver.task_type = 'worker' + cluster_resolver.task_index = 0 + self.assertEqual(cluster_resolver.task_type, 'worker') + self.assertEqual(cluster_resolver.task_index, 0) + self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470') + self.assertEqual(cluster_resolver.master('worker', 2), + 'grpc://10.1.2.5:8470') + def testNonRunningPod(self): ret = _create_pod_list(('tensorflow-abc123', 'Failed', '10.1.2.3'),) diff --git a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py index dabe2fe1d39db14c60e5437d636144f18c384cf1..f590ecead96565672af30c2f3702f1a21f4317be 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py @@ -53,7 +53,8 @@ class SlurmClusterResolver(ClusterResolver): gpus_per_node=1, gpus_per_task=1, tasks_per_node=None, - auto_set_gpu=True): + auto_set_gpu=True, + rpc_layer='grpc'): """Creates a new SlurmClusterResolver object. This takes in parameters and creates a SlurmClusterResolver object. It uses @@ -74,6 +75,8 @@ class SlurmClusterResolver(ClusterResolver): auto_set_gpu: Set the visible CUDA devices automatically while resolving the cluster by setting CUDA_VISIBLE_DEVICES environment variable. Defaults to True. + rpc_layer: (Optional) The protocol TensorFlow uses to communicate between + nodes. Defaults to 'grpc'. Returns: A ClusterResolver object which can be used with distributed TensorFlow. 
@@ -107,8 +110,9 @@ class SlurmClusterResolver(ClusterResolver): self._gpus_per_task = gpus_per_task self._auto_set_gpu = auto_set_gpu - self._job_name = None - self._task_index = None + self.task_type = None + self.task_index = None + self.rpc_layer = rpc_layer self._gpu_allocation = [] self._cluster_allocation = {} @@ -157,17 +161,15 @@ class SlurmClusterResolver(ClusterResolver): cluster_rank_offset_start = 0 cluster_rank_offset_end = 0 - for job_name, num_tasks in self._jobs.items(): + for task_type, num_tasks in self._jobs.items(): cluster_rank_offset_end = cluster_rank_offset_start + num_tasks - self._cluster_allocation[job_name] = \ - task_list[cluster_rank_offset_start:cluster_rank_offset_end] + self._cluster_allocation[task_type] = ( + task_list[cluster_rank_offset_start:cluster_rank_offset_end]) - if self._rank >= cluster_rank_offset_start and \ - self._rank < cluster_rank_offset_end: - - self._job_name = job_name - self._task_index = self._rank - cluster_rank_offset_start + if cluster_rank_offset_start <= self._rank < cluster_rank_offset_end: + self.task_type = task_type + self.task_index = self._rank - cluster_rank_offset_start cluster_rank_offset_start = cluster_rank_offset_end @@ -188,9 +190,37 @@ class SlurmClusterResolver(ClusterResolver): A string specifying job name the process belongs to and an integner specifying the task index the process belongs to in that job. """ - return self._job_name, self._task_index + return self.task_type, self.task_index + + def master(self, task_type=None, task_index=None, rpc_layer=None): + """Returns the master string for connecting to a TensorFlow master. + + Args: + task_type: (Optional) Overrides the default auto-selected task type. + task_index: (Optional) Overrides the default auto-slected task index. + rpc_layer: (Optional) Overrides the default RPC protocol TensorFlow uses + to communicate across nodes. + + Returns: + A connection string for connecting to a TensorFlow master. 
+ """ + task_type = task_type if task_type is not None else self.task_type + task_index = task_index if task_index is not None else self.task_index + rpc_layer = rpc_layer or self.rpc_layer + master = self.cluster_spec().task_address(task_type, task_index) + + return '%s://%s' % (rpc_layer, master) if rpc_layer else master + + @property + def environment(self): + """Returns the current environment which TensorFlow is running in. + + For users in the Slurm environment, the environment property is always an + empty string, and Google users will not use this ClusterResolver for running + on internal systems. + """ + return '' - def master(self, task_type=None, task_index=None): - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) - return self._cluster_allocation[str(self._job_name)][self._task_index] + def num_accelerators_per_worker(self, session_config=None): + del session_config # Unused, since this is set in __init__ manually. + return self._gpus_per_node diff --git a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver_test.py index 9aa7df745eb8e1c444011485687b213d87c37da5..7c76e133fe4762f3ea072ef4784cba00996b95cc 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver_test.py @@ -67,6 +67,31 @@ class SlurmClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.dict(os.environ, {'SLURM_PROCID': '0', 'SLURM_NTASKS': '3'}) + @mock.patch.object(SlurmClusterResolver, '_resolve_hostnames', + mock_resolve_hostnames_output) + def testSimpleMasterRetrieval(self): + slurm_cluster_resolver = SlurmClusterResolver( + jobs={ + 'ps': 1, + 'worker': 2 + }, + port_base=8888, + tasks_per_node=1, + gpus_per_node=1, + gpus_per_task=1, + 
auto_set_gpu=False) + + slurm_cluster_resolver.task_type = 'worker' + slurm_cluster_resolver.task_index = 1 + self.assertEqual(slurm_cluster_resolver.master(), 'grpc://t02n43:8888') + + slurm_cluster_resolver.rpc_layer = 'ab' + self.assertEqual(slurm_cluster_resolver.master('ps', 0), 'ab://t02n13:8888') + self.assertEqual( + slurm_cluster_resolver.master('ps', 0, rpc_layer='test'), + 'test://t02n13:8888') + @mock.patch.dict(os.environ, { 'SLURM_PROCID': '0', 'SLURM_NTASKS': '3', diff --git a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py index 7bbd189d03d9c96914d11948941916739f10d18f..95aad0de1378dbee47ba24ff903da31fdb18a1af 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py @@ -27,13 +27,98 @@ from tensorflow.python.training.server_lib import ClusterSpec _TF_CONFIG_ENV = 'TF_CONFIG' _SESSION_MASTER_KEY = 'session_master' +_RPC_LAYER_KEY = 'rpc_layer' +_TASK_KEY = 'task' + + +def format_master_url(master, rpc_layer=None): + if rpc_layer: + return '%s://%s' % (rpc_layer, master) + else: + return master + + +def _load_tf_config(): + return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) + + +def _get_value_in_tfconfig(key, default=None): + tf_config = _load_tf_config() + return tf_config[key] if key in tf_config else default class TFConfigClusterResolver(ClusterResolver): """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.""" - def _load_tf_config(self): - return json.loads(os.environ.get(_TF_CONFIG_ENV, '{}')) + def __init__(self, + task_type=None, + task_index=None, + rpc_layer=None, + environment=None, + num_accelerators_per_worker=0): + """Creates a new TFConfigClusterResolver. + + Args: + task_type: (String, optional) Overrides the task type specified in the + TF_CONFIG environment variable. 
+ task_index: (Integer, optional) Overrides the task index specified in the + TF_CONFIG environment variable. + rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses. + environment: (String, optional) Overrides the environment TensorFlow + operates in. + num_accelerators_per_worker: (Integer, optional) Specifies the number of + accelerators (e.g. GPUs, TPUs, others) that each node has. + """ + + self._task_type = task_type + self._task_index = task_index + self._rpc_layer = rpc_layer + self._environment = environment + self._num_accelerators_per_worker = num_accelerators_per_worker + + @property + def task_type(self): + if self._task_type is None: + task_info = _get_value_in_tfconfig(_TASK_KEY, {}) + return task_info['type'] if 'type' in task_info else None + else: + return self._task_type + + @property + def task_index(self): + if self._task_type is None: + task_info = _get_value_in_tfconfig(_TASK_KEY, {}) + return task_info['index'] if 'index' in task_info else None + else: + return self._task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._environment + + @property + def rpc_layer(self): + if self._rpc_layer is None: + return _get_value_in_tfconfig(_RPC_LAYER_KEY) + else: + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer + + def num_accelerators_per_worker(self, session_config=None): + # TODO(frankchn): Connect to server (w/ session_config) in the future. + del session_config # Unused, we do not connect to another server here. + return self._num_accelerators_per_worker def cluster_spec(self): """Returns a ClusterSpec based on the TF_CONFIG environment variable. 
@@ -41,12 +126,12 @@ class TFConfigClusterResolver(ClusterResolver): Returns: A ClusterSpec with information from the TF_CONFIG environment variable. """ - tf_config = self._load_tf_config() + tf_config = _load_tf_config() if 'cluster' not in tf_config: return ClusterSpec({}) return ClusterSpec(tf_config['cluster']) - def master(self, task_type=None, task_index=0): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a TensorFlow session. Args: @@ -54,6 +139,8 @@ class TFConfigClusterResolver(ClusterResolver): master. task_index: (Integer, optional) Overrides and sets the task id of the master. + rpc_layer: (String, optional) Overrides and sets the protocol over which + TensorFlow nodes communicate with each other. Returns: The address of the master. @@ -64,14 +151,9 @@ class TFConfigClusterResolver(ClusterResolver): """ # If `session_master` is set, just use that. - tf_config = self._load_tf_config() - if _SESSION_MASTER_KEY in tf_config: - return tf_config[_SESSION_MASTER_KEY] - - if 'rpc_layer' in tf_config: - rpclayer = '%s://' % tf_config['rpc_layer'] - else: - rpclayer = '' + session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY) + if session_master is not None: + return session_master # Return an empty string if we are the only job in the ClusterSpec. 
cluster_spec = self.cluster_spec() @@ -82,11 +164,8 @@ class TFConfigClusterResolver(ClusterResolver): # We try to auto-detect the task type and id, but uses the user-supplied one # where available - if not task_type: - if 'task' not in tf_config: - raise RuntimeError('You must either specify a `task_type`, or your ' - 'TF_CONFIG must contain a `task` section.') - task_type = tf_config['task']['type'] - task_index = tf_config['task']['index'] - - return rpclayer + cluster_spec.task_address(task_type, task_index) + task_type = task_type if task_type is not None else self.task_type + task_index = task_index if task_index is not None else self.task_index + + return format_master_url(cluster_spec.task_address(task_type, task_index), + self.rpc_layer) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver_test.py index 468161d2aa49129f2ec960b1ccddf49c712f00a7..3db6d5447f5abab6936a2ab4b4a149715ec01394 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver_test.py @@ -133,6 +133,58 @@ class TFConfigClusterResolverTest(test.TestCase): cluster_resolver = TFConfigClusterResolver() self.assertEqual('grpc://ps0:2222', cluster_resolver.master()) + def testTaskTypeIndexRpcRead(self): + os.environ['TF_CONFIG'] = """ + { + "cluster": { + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }, + "rpc_layer": "grpc", + "task": { + "type": "ps", + "index": 0 + } + } + """ + + cluster_resolver = TFConfigClusterResolver() + self.assertEqual('ps', cluster_resolver.task_type) + self.assertEqual(0, cluster_resolver.task_index) + self.assertEqual('grpc', cluster_resolver.rpc_layer) + + def testParameterOverrides(self): + os.environ['TF_CONFIG'] = """ + { + "cluster": { + "ps": ["ps0:2222", 
"ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }, + "rpc_layer": "grpc", + "task": { + "type": "ps", + "index": 1 + } + } + """ + + cluster_resolver = TFConfigClusterResolver(task_type='ps', task_index=0, + num_accelerators_per_worker=8) + + self.assertEqual('grpc://ps0:2222', cluster_resolver.master()) + self.assertEqual('ps', cluster_resolver.task_type) + self.assertEqual(0, cluster_resolver.task_index) + self.assertEqual(8, cluster_resolver.num_accelerators_per_worker()) + + cluster_resolver.task_type = 'worker' + cluster_resolver.task_index = 1 + cluster_resolver.rpc_layer = 'test' + + self.assertEqual('test://worker1:2222', cluster_resolver.master()) + self.assertEqual('worker', cluster_resolver.task_type) + self.assertEqual(1, cluster_resolver.task_index) + self.assertEqual('test', cluster_resolver.rpc_layer) + def testZeroItemsInClusterSpecMasterRead(self): os.environ['TF_CONFIG'] = """ {} diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index c4ac9d0700194da558820aabc28bf1c0857591e2..d5537a4100ddad19d2a9131b971f3d604d58f8f2 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -24,6 +24,7 @@ from six.moves.urllib.request import Request from six.moves.urllib.request import urlopen from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver +from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import format_master_url from tensorflow.python.training import server_lib from tensorflow.python.util import compat @@ -50,6 +51,34 @@ class TPUClusterResolver(ClusterResolver): Cloud Platform project. """ + def _tpuService(self): + """Creates a new Cloud TPU API object. 
+ + This works around an issue where the underlying HTTP connection sometimes + times out when the script has been running for too long. Other methods in + this object calls this method to get a new API object whenever they need + to communicate with the Cloud API. + + Returns: + A Google Cloud TPU API object. + """ + if self._service: + return self._service + + credentials = self._credentials + if credentials is None or credentials == 'default': + credentials = GoogleCredentials.get_application_default() + + if self._discovery_url: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials, + discoveryServiceUrl=self._discovery_url) + else: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials) + def _requestComputeMetadata(self, path): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) @@ -57,6 +86,8 @@ class TPUClusterResolver(ClusterResolver): return compat.as_bytes(resp.read()) def _shouldResolve(self): + if isinstance(self._should_resolve_override, bool): + return self._should_resolve_override if (self._tpu == compat.as_bytes('') or self._tpu == compat.as_bytes('local') or self._tpu.startswith(compat.as_bytes('/bns')) or @@ -81,7 +112,7 @@ class TPUClusterResolver(ClusterResolver): return None @staticmethod - def _discoveryUrl(): + def _environmentDiscoveryUrl(): return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) def __init__(self, @@ -153,55 +184,80 @@ class TPUClusterResolver(ClusterResolver): raise ValueError('Please provide a TPU Name to connect to.') self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes - self._job_name = job_name - self._credentials = credentials + # By default the task_type is 'worker` and the task_index is 0 (which is the + # first worker in the task). + self.task_type = job_name + self.task_index = 0 + + if tpu.startswith('grpc://'): + # Cloud environment, where we are using GRPC to communicate to TPUs. 
+ self._environment = '' + elif tpu == 'local' or not tpu: + # Google environment, where the TPU is attached to the host. + self._environment = 'google' + elif tpu.startswith('/bns'): + # Google environment, where we reach the TPU through BNS. + self._environment = 'google' + + # If TPU is in the Google environment or exists locally, we don't use any + # RPC layer. + if tpu.startswith('/bns') or tpu == 'local' or not tpu: + self.rpc_layer = None + else: + self.rpc_layer = 'grpc' + + # Setting this overrides the return value of self._shouldResolve() + self._should_resolve_override = None + + # We strip out the protocol if it is included, and override the + # shouldResolve function to never resolve. We are adding the protocol back + # in later in self.master(). + if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'): + tpu = tpu[len(self.rpc_layer + '://'):] + self._tpu = tpu + self._should_resolve_override = False + + # Whether we should actually attempt to contact Cloud APIs should_resolve = self._shouldResolve() + # We error out if we are in a non-Cloud environment which cannot talk to the + # Cloud APIs using the standard class and a special object is not passed in. + self._service = service + if (self._service is None and should_resolve and + not _GOOGLE_API_CLIENT_INSTALLED): + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + # We save user-passed credentials, unless the user didn't pass in anything. + self._credentials = credentials + if (credentials == 'default' and should_resolve and + _GOOGLE_API_CLIENT_INSTALLED): + self._credentials = None + + # Automatically detect project and zone if unspecified. 
if not project and should_resolve: project = compat.as_str( self._requestComputeMetadata('project/project-id')) - if not zone and should_resolve: zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] - self._project = project self._zone = zone - if credentials == 'default' and should_resolve: - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None and should_resolve: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - final_discovery_url = self._discoveryUrl() or discovery_url - if final_discovery_url: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials, - discoveryServiceUrl=final_discovery_url) - else: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) - else: - self._service = service + self._discovery_url = self._environmentDiscoveryUrl() or discovery_url self._coordinator_name = coordinator_name - if coordinator_name and not coordinator_address and (should_resolve or - in_gke): + if (coordinator_name and not coordinator_address and + (should_resolve or in_gke)): self._start_local_server() else: self._coordinator_address = coordinator_address - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Get the Master string to be used for the session. In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of @@ -213,8 +269,12 @@ class TPUClusterResolver(ClusterResolver): 'grpc://10.240.1.2:8470' will be returned). Args: - task_type: (Optional) The type of the TensorFlow task of the master. 
- task_index: (Optional) The index of the TensorFlow task of the master. + task_type: (Optional, string) The type of the TensorFlow task of the + master. + task_index: (Optional, integer) The index of the TensorFlow task of the + master. + rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to + communicate with TPUs. Returns: string, the connection string to use when creating a session. @@ -222,25 +282,34 @@ class TPUClusterResolver(ClusterResolver): Raises: ValueError: If none of the TPUs specified exists. """ - if not self._shouldResolve(): - return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] - - cluster_spec = self.cluster_spec() - if task_type and task_index: - return cluster_spec.task_address(task_type, task_index) - - job_tasks = cluster_spec.job_tasks(self._job_name) - if not job_tasks: - raise ValueError('No TPUs exists with the specified names exist.') - - return 'grpc://' + job_tasks[0] + if self._shouldResolve(): + # We are going to communicate with the Cloud TPU APIs to get a Cluster. 
+ cluster_spec = self.cluster_spec() + if task_type is not None and task_index is not None: + # task_type and task_index is from the function parameter + master = cluster_spec.task_address(task_type, task_index) + elif self.task_type is not None and self.task_index is not None: + # task_type and task_index is from the object + master = cluster_spec.task_address(self.task_type, self.task_index) + else: + # by default we take the first item in the cluster with the right name + job_tasks = cluster_spec.job_tasks(self.task_type) + if not job_tasks: + raise ValueError('No TPUs with the specified names exist.') + master = job_tasks[0] + else: + if isinstance(self._tpu, (bytes, bytearray)): + master = self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] + else: + master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0] + return format_master_url(master, rpc_layer or self.rpc_layer) def get_master(self): return self.master() def get_job_name(self): if self._shouldResolve(): - return self._job_name + return self.task_type def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. @@ -270,7 +339,8 @@ class TPUClusterResolver(ClusterResolver): # Case 1. full_name = 'projects/%s/locations/%s/nodes/%s' % ( self._project, self._zone, compat.as_text(self._tpu)) - request = self._service.projects().locations().nodes().get(name=full_name) + service = self._tpuService() + request = service.projects().locations().nodes().get(name=full_name) response = request.execute() if 'state' in response and response['state'] != 'READY': @@ -291,18 +361,23 @@ class TPUClusterResolver(ClusterResolver): instance_url = '%s:%s' % (response['ipAddress'], response['port']) worker_list = [instance_url] - cluster_spec = {self._job_name: worker_list} + cluster_spec = {self.task_type: worker_list} else: - if not self._tpu.startswith(compat.as_bytes('grpc://')): + if self.rpc_layer is None: # Case 3. return None # Case 2. 
- cluster_spec = { - self._job_name: [ - x[len(compat.as_bytes('grpc://')):] - for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) - ] - } + tpus = [] + for tpu in self._tpu.split(_ENDPOINTS_SEPARATOR): + # We are working around the fact that GKE environment variable that is + # supplied to us has the protocol string embedded in it, but we want + # to strip it out for the ClusterSpec. + if (self.rpc_layer is not None and + tpu.startswith(self.rpc_layer + '://')): + tpus.append(tpu[len(self.rpc_layer + '://'):]) + else: + tpus.append(tpu) + cluster_spec = {self.task_type: tpus} if self._coordinator_address: # {1, 2}.a @@ -310,6 +385,24 @@ class TPUClusterResolver(ClusterResolver): return server_lib.ClusterSpec(cluster_spec) + def num_accelerators_per_worker(self, session_config=None): + """Returns the number of TPU cores per worker. + + This defaults to 8 for all current TPU configurations, and we do not need + to query any remote systems for this. + + Args: + session_config: Unused. Not currently necessary to query anything as this + number is 8 for all TPU configurations. + """ + del session_config # Unused. Not necessary to query anything. 
+ return 8 + + @property + def environment(self): + """Returns the current environment which TensorFlow is running in.""" + return self._environment + def _start_local_server(self): address = self._requestComputeMetadata('instance/network-interfaces/0/ip') self._server = server_lib.Server( diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index ad4f6432630be44a7de6e778f55f1fb7fd66f307..365bd52ee254b38588b3dfb20d64f7839e720df4 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -132,6 +132,7 @@ class TPUClusterResolverTest(test.TestCase): } """ % tpu_cluster_resolver._coordinator_port self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.1.2.3:8470') @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', mock_request_compute_metadata) @@ -157,6 +158,7 @@ class TPUClusterResolverTest(test.TestCase): job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.1.2.3:8470') @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', mock_request_compute_metadata) @@ -226,6 +228,7 @@ class TPUClusterResolverTest(test.TestCase): job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.1.2.3:8470') def testNewNetworkEndpointFormat(self): tpu_map = { @@ -304,6 +307,7 @@ class TPUClusterResolverTest(test.TestCase): } """ % tpu_cluster_resolver._coordinator_port self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + 
self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.2.3.4:8470') def testPodResolutionNoCoordinator(self): tpu_map = { @@ -350,6 +354,7 @@ class TPUClusterResolverTest(test.TestCase): } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.2.3.4:8470') def testGetMasterNoEntries(self): tpu_map = {} @@ -459,10 +464,67 @@ class TPUClusterResolverTest(test.TestCase): del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] - def testDiscoveryUrl(self): + def testEnvironmentDiscoveryUrl(self): os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' self.assertEqual('https://{api}.internal/{apiVersion}', - TPUClusterResolver._discoveryUrl()) + TPUClusterResolver._environmentDiscoveryUrl()) + + def testEnvironmentAndRpcDetectionForGoogle(self): + tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/ab/cd/ef') + self.assertEqual(tpu_cluster_resolver.environment, 'google') + self.assertEqual(tpu_cluster_resolver.rpc_layer, None) + + def testEnvironmentAndRpcDetectionForGrpcString(self): + tpu_cluster_resolver = TPUClusterResolver(tpu='grpc://10.1.2.3:8470') + self.assertEqual(tpu_cluster_resolver.environment, '') + self.assertEqual(tpu_cluster_resolver.rpc_layer, 'grpc') + self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.1.2.3:8470') + + def testOverrideTaskTypeAndIndexAndGetMaster(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + 
self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.2.3.4:8470') + + tpu_cluster_resolver.task_type = 'worker' + tpu_cluster_resolver.task_index = 3 + self.assertEqual(tpu_cluster_resolver.master(), 'grpc://10.2.3.7:8470') + + self.assertEqual( + tpu_cluster_resolver.master( + task_type='worker', task_index=2, rpc_layer='test'), + 'test://10.2.3.6:8470') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index fbdca497fcc3126d2086d289ebdb113370072d22..a63366e1361effe20787c197eddd66b5c0c96410 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -59,8 +59,6 @@ option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires M # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) -set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") -set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) @@ -72,25 +70,25 @@ endif() if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option # for targets that link ${CMAKE_THREAD_LIBS_INIT}. - find_package (Threads) + find_package (Threads REQUIRED) # Options for linking CUDA/CUDNN libraries - option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) + option(tensorflow_PATH_CUDA_LIB "Additional library search path for cudnn, nccl, culibos" /usr/local/cuda/lib64/) option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) if (NOT tensorflow_CUDNN_INCLUDE) # option's default value is OFF. 
Fill it with real default values set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) - option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + option(tensorflow_PATH_CUDNN_LIB "Override PATH_CUDA_LIB for cudnn" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_CUDNN_LIB) # option's default value is OFF. Fill it with real default values - set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) - option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_CUDNN_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_CUDNN_LIB) + option(tensorflow_PATH_NCCL_LIB "Override PATH_CUDA_LIB for nccl" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_NCCL_LIB) # option's default value is OFF. Fill it with real default values - set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_NCCL_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_NCCL_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. 
Fill it with real default values @@ -210,14 +208,17 @@ endif() include(CheckCXXCompilerFlag) # OpenMP Support -CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) -if (GCC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") -endif() -CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) -if (MSVC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") -endif() +if (WIN32) + CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) + if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") + endif() +else (WIN32) + CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) + if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + endif() +endif (WIN32) # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) @@ -377,29 +378,19 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - # later command will make use of the value in tensorflow_CUDA_VERSION - find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED EXACT) - - # Test compatibility of compiler on CUDA - try_compile(CUDA_TEST_COMPILE_C - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.c - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - try_compile(CUDA_TEST_COMPILE_CXX - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.cc - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - if(NOT (CUDA_TEST_COMPILE_C AND CUDA_TEST_COMPILE_CXX)) - message(FATAL_ERROR "Selected compiler (or version) is not supported for CUDA") + # minimum 9.1 in cuda version + find_package(CUDA 9.1 REQUIRED) + if(NOT CUDA_FOUND) + message(FATAL_ERROR "CUDA not found.") endif() - # by default we assume compute cabability 3.5 and 5.2. 
If you change this change it in - # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_37,code=\"sm_37,compute_37\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_52,code=\"sm_52,compute_52\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_60,code=\"sm_60,compute_60\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_61,code=\"sm_61,compute_61\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_70,code=\"sm_70,compute_70\") + # use cmake internal CUDA_ARCH_NAME switch + # e.g. CUDA_ARCH_NAME="Auto" will autodetect + # CUDA_ARCH_NAME="All" will use all arches + cuda_select_nvcc_arch_flags(NVCC_ARCH_FLAGS ${CUDA_ARCH_NAME}) + list(APPEND CUDA_NVCC_FLAGS ${NVCC_ARCH_FLAGS}) + message(STATUS "Using CUDA arch flags: ${NVCC_ARCH_FLAGS_readable}") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) @@ -423,43 +414,94 @@ if (tensorflow_ENABLE_GPU) else (WIN32) set(CUDNN_INCLUDE "${tensorflow_CUDNN_INCLUDE}") - find_library(nccl_STATIC_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT nccl_STATIC_LIBRARY) + if (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl.so PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT nccl_LIBRARY) message(FATAL_ERROR "NCCL is required for GPU-build") - else (NOT nccl_STATIC_LIBRARY) - message("nccl-static: ${nccl_STATIC_LIBRARY}") + else (NOT nccl_LIBRARY) + message("nccl: ${nccl_LIBRARY}") # something like 
/usr/lib64/libnccl_static.a - endif (NOT nccl_STATIC_LIBRARY) - - find_library(cudnn_STATIC_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT cudnn_STATIC_LIBRARY) + endif (NOT nccl_LIBRARY) + + if (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn.so PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT cudnn_LIBRARY) message(FATAL_ERROR "CUDNN is required for GPU-build") - else (NOT cudnn_STATIC_LIBRARY) - message("cudnn-static: ${cudnn_STATIC_LIBRARY}") - endif (NOT cudnn_STATIC_LIBRARY) - - find_library(culibos_STATIC_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT culibos_STATIC_LIBRARY) + else (NOT cudnn_LIBRARY) + file(READ ${CUDNN_INCLUDE}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) + # fetch cudnn version + string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}") + string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}") + if(NOT CUDNN_VERSION_MAJOR) + set(CUDNN_VERSION "???") + else() + set(CUDNN_VERSION "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}") + endif() + message(STATUS "cudnn library: ${cudnn_LIBRARY} (found version: \"${CUDNN_VERSION}\")") + endif (NOT 
cudnn_LIBRARY) + + if (tensorflow_BUILD_SHARED_LIB) + # shared first (if exists) else static one + find_library(culibos_LIBRARY NAMES libculibos.so libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + # only static version + find_library(culibos_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT culibos_LIBRARY) message(FATAL_ERROR "CULIBOS is required for GPU-build") - else (NOT culibos_STATIC_LIBRARY) - message("culibos-static: ${culibos_STATIC_LIBRARY}") - endif (NOT culibos_STATIC_LIBRARY) + else (NOT culibos_LIBRARY) + message("culibos: ${culibos_LIBRARY}") + endif (NOT culibos_LIBRARY) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} - ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) + ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_LIBRARY} ${culibos_LIBRARY} ${nccl_LIBRARY}) endif (WIN32) include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. - string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + string(REPLACE "." 
"" short_CUDA_VER ${CUDA_VERSION}) + + # List of enumerated CUDA caps + string(REPLACE " " ";" NVCC_ARCH_LIST "${NVCC_ARCH_FLAGS_readable}") + set(list ${NVCC_ARCH_LIST}) + + # Construct capability string + foreach(NVCC_ARCH ${NVCC_ARCH_LIST}) + if (NVCC_ARCH MATCHES "sm_") + string(REGEX REPLACE "^.sm*" "" NVCC_ARCH ${NVCC_ARCH}) + math(EXPR NVCC_ARCH_MAJOR "${NVCC_ARCH} / 10") + math(EXPR NVCC_ARCH_MINOR "(${NVCC_ARCH} - (${NVCC_ARCH_MAJOR}*10))") + if (TF_CUDA_CAP) + set(TF_CUDA_CAP "${TF_CUDA_CAP},CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + else (TF_CUDA_CAP) + set(TF_CUDA_CAP "CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + endif (TF_CUDA_CAP) + endif() + endforeach() # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.7\"),CudaVersion(\"5.2\"),CudaVersion(\"6.0\"),CudaVersion(\"6.1\"),CudaVersion(\"7.0\")\n" + "#define TF_CUDA_CAPABILITIES ${TF_CUDA_CAP}\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" - "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" + "#define TF_CUDNN_VERSION \"64_${CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -494,14 +536,14 @@ if (tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll cudart_dll_name=cudart64_${short_CUDA_VER}.dll - cuda_version_number=${tensorflow_CUDA_VERSION} + cuda_version_number=${CUDA_VERSION} nvcuda_dll_name=nvcuda.dll cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=${tensorflow_CUDA_VERSION} - cudnn_version_number=${tensorflow_CUDNN_VERSION}) + cuda_version_number=${CUDA_VERSION} + 
cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index c6c5021f60b38ed05a19f3e439c9810251841f76..4546dbdecc0dbc36f17cc727345e0762718b5165 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -20,6 +20,7 @@ if (systemlib_ABSEIL_CPP) absl_dynamic_annotations absl_malloc_internal absl_throw_delegate + absl_int128 absl_strings str_format_internal absl_bad_optional_access) @@ -50,6 +51,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) @@ -60,6 +62,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) @@ -71,6 +74,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/libabsl_dynamic_annotations.a ${abseil_cpp_BUILD}/absl/base/libabsl_malloc_internal.a ${abseil_cpp_BUILD}/absl/base/libabsl_throw_delegate.a + ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a 
${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index d94b703700cfcd9ecae7f1d2718ba33ffd82c176..96160568fa79291a7b391761373e1eaf0f70974e 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -57,6 +57,7 @@ tensorflow/python/ops tensorflow/python/ops/distributions tensorflow/python/ops/linalg tensorflow/python/ops/losses +tensorflow/python/ops/signal tensorflow/python/platform tensorflow/python/profiler tensorflow/python/profiler/internal @@ -377,8 +378,6 @@ tensorflow/contrib/seq2seq/python/ops tensorflow/contrib/session_bundle tensorflow/contrib/session_bundle/example tensorflow/contrib/signal -tensorflow/contrib/signal/python -tensorflow/contrib/signal/python/ops tensorflow/contrib/slim tensorflow/contrib/slim/python tensorflow/contrib/slim/python/slim diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 88b4a6165c0f9171ec7cc169bc099c7db1549ee7..d66e39ac07c7b7c9423fa7e878a9cefd94b867bd 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,14 +68,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - 
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index ef337b3a15c4f58fe183af78d34376b3ed27099a..9cfa8b90749280b6aa815cc210941c75bd5e16c5 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -89,7 +89,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index ef487d3509bf3c9bfaf0b117998e6b121543c1c6..df7b854afcca1a0bed660624152f465d4bf3b25f 100755 --- 
a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -373,8 +373,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index 8d13dc7316a693657f1b6e102830808d35372fe9..3b49755afcf0753d31c0ce506dce42709b1ee8bc 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -49,7 +48,7 @@ class XLACompileContextTest(test.TestCase): histogram_summary = summary.histogram('histogram_summary', dummy_tensor) image_summary = summary.image('image_summary', dummy_tensor) scalar_summary = summary.scalar('scalar_summary', dummy_tensor) - tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor) + tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) summary.merge( [ audio_summary, histogram_summary, image_summary, scalar_summary, diff --git 
a/tensorflow/contrib/copy_graph/python/__init__.py b/tensorflow/contrib/copy_graph/python/__init__.py index b9ff28eb0d7115ff5919c2f758f70ba388f5d4d2..5c1048e02a3104c958f7710ba97980d3353adbad 100644 --- a/tensorflow/contrib/copy_graph/python/__init__.py +++ b/tensorflow/contrib/copy_graph/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/copy_graph/python/util/__init__.py b/tensorflow/contrib/copy_graph/python/util/__init__.py index b9ff28eb0d7115ff5919c2f758f70ba388f5d4d2..5c1048e02a3104c958f7710ba97980d3353adbad 100644 --- a/tensorflow/contrib/copy_graph/python/util/__init__.py +++ b/tensorflow/contrib/copy_graph/python/util/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index fe5e34d258fbc1508a0a85655f29c2c9bc8fa8b1..d53549048f33162ec89dfe957ca58a4bbb4e95c6 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """Linear-chain CRF layer. -See the [CRF](https://tensorflow.org/api_guides/python/contrib.crf) guide. 
- @@crf_binary_score @@crf_decode @@crf_log_likelihood diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 57ffaa87e45559a6ecf4c8059e5a6cdee8b8b664..8d35622e393e15a2f2dfea7c75ad2c9f48aa7150 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -42,10 +42,11 @@ tf_custom_op_py_library( cuda_py_test( name = "cudnn_rnn_ops_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ ":cudnn_rnn_py", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", @@ -61,10 +62,10 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - shard_count = 6, + shard_count = 2, tags = [ - "no_oss", # b/117989214 "noasan", # http://b/62067814 + "requires-gpu-sm35", ], ) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 5d8c6191f8db9f96532aa78e4790a4665d3b4877..5320232268657fa73bcd3e86da49d6525e9b8db5 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -24,6 +24,10 @@ @@CudnnGRUSaveable @@CudnnRNNReluSaveable @@CudnnRNNTanhSaveable +@@CudnnParamsFormatConverterLSTM +@@CudnnParamsFormatConverterGRU +@@CudnnParamsFormatConverterTanh +@@CudnnParamsFormatConverterRelu """ from __future__ import absolute_import @@ -48,6 +52,10 @@ _allowed_symbols = [ "CudnnGRUSaveable", "CudnnRNNReluSaveable", "CudnnRNNTanhSaveable", + "CudnnParamsFormatConverterLSTM", + "CudnnParamsFormatConverterGRU", + "CudnnParamsFormatConverterTanh", + "CudnnParamsFormatConverterRelu", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 
c59d3682d404e032d9f4bf81ef54ab456341cefa..1e2c9121d63267692ee80f14299392e19ab95a88 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -18,24 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import itertools import os import unittest +from absl.testing import parameterized import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework.test_util import TensorFlowTestCase from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -56,710 +62,991 @@ CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER -def _CreateModel(rnn_mode, - num_layers, - num_units, - input_size, - input_mode="linear_input", - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0.): - del input_mode - if rnn_mode == cudnn_rnn_ops.CUDNN_LSTM: - model_fn = cudnn_rnn_ops.CudnnLSTM - elif rnn_mode 
== cudnn_rnn_ops.CUDNN_GRU: - model_fn = cudnn_rnn_ops.CudnnGRU - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_TANH: - model_fn = cudnn_rnn_ops.CudnnRNNTanh - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_RELU: - model_fn = cudnn_rnn_ops.CudnnRNNRelu +def RunLSTM(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. + # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. when test training.") + + # set graph level random seed and numpy random seed. + random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_c_op = variable_scope.get_variable( + "initial_c_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + + with variable_scope.variable_scope("test", initializer=initializer): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + num_units, num_units * 4], + dtype=dtype) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) + + # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. 
+ cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) + outputs_op, state_tuple_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=rnn_cell_impl.LSTMStateTuple( + h=initial_h_op, c=initial_c_op), + dtype=dtype, + time_major=True, + scope=None) + + # Convert to cudnn opaque param. + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque([w, b]) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + cu_initial_c_op, + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) + # Remove the trivial 1st dimension. + cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=0), + h=array_ops.squeeze(cu_h_op, axis=0)) + + if is_training: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + + (cu_inp_grad_op, cu_hgrad_op, + cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, + [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + # Remove the trivial 1st dimension + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + cu_wgrad_op = cu_wgrad_op[0] + cu_bgrad_op = cu_bgrad_op[0] + # cudnn lstm has 2 biases each gate. When converting to tf canonical format, + # the two biases are summed into one. Thus here bias gradient should be + # halved when comparing with tf lstm. 
+ cu_bgrad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad) = sess.run([ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "state_grad: %s" % str(state_grad)) + logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad)) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) else: - raise ValueError("Invalid rnn_mode: %s" % rnn_mode) - return model_fn( - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - - -def _CreateParamsSavable(params, - model, - base_variable_scope=None, - name="params_canonical"): - """Create a RNNParamsSaveable for the weight and bias parameters. 
+ outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + return outputs, cu_outputs, state_tuple, cu_state_tuple + + +# Basic set of RNN configs to test. They can be further extended in relevant +# test (e.g. adding num_dirs). +NAMED_RNN_TESTCASES = ({ + "testcase_name": "xsmall", + "num_units": 1, + "input_size": 1, + "batch_size": 1, + "time": 1, + "num_layers": 1, +}, { + "testcase_name": "small", + "num_units": 4, + "input_size": 4, + "batch_size": 4, + "time": 4, + "num_layers": 1, +}, { + "testcase_name": "medium", + "num_units": 128, + "input_size": 64, + "batch_size": 8, + "time": 16, + "num_layers": 1, +}, { + "testcase_name": "large", + "num_units": 128, + "input_size": 128, + "batch_size": 16, + "time": 32, + "num_layers": 1, +}) + + +def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): + """Expands testcase with new config dimensions. 
+ + Example: + inputs = ( + {'testcase_name': 'test1', 'gender': 'male'} + {'testcase_name': 'test2', 'gender': 'female'} + ) + remove_keys: empty + extra_configs = { + 'age': [40, 80] + 'height': [5, 6] + } + + Returns: + ( + {'testcase_name': 'test1_age_40_height_5','gender': 'male', 'age': + 40,'height': 5} + {'testcase_name': 'test1_age_40_height_6', 'gender': 'male', 'age': 40, + 'height': 6} + {'testcase_name': 'test1_age_80_height_5', 'gender': 'male', 'age': 80, + 'height': 5} + {'testcase_name': 'test1_age_80_height_6', 'gender': 'male', 'age': 80, + 'height': 6} + + {'testcase_name': 'test2_age_40_height_5', 'gender': 'female', 'age': + 40, + 'height': 5} + {'testcase_name': 'test2_age_40_height_6', 'gender': 'female', 'age': + 40, + 'height': 6} + {'testcase_name': 'test2_age_80_height_5', 'gender': 'female', 'age': + 80, + 'height': 5} + {'testcase_name': 'test2_age_80_height_6', 'gender': 'female', 'age': + 80, + 'height': 6} + ) Args: - params: a Variable for weight and bias parameters. - model: a CudnnRNN model. - base_variable_scope: a string, prefix of names of saved variables. - name: a string, name of the RNNParamsSaveable object. + inputs: A list of dictionary, each being a testcase. + *remove_keys: A list of keys into testcase which are not needed in new + testcases. + **extra_configs: A dict of new test dimension and applicable values in that + dimension. + Returns: - a RNNParamsSaveable object. + A list of dictionary with expanded test cases. 
""" - if model._rnn_mode == CUDNN_LSTM: - fn = cudnn_rnn_ops.CudnnLSTMSaveable - elif model._rnn_mode == CUDNN_GRU: - fn = cudnn_rnn_ops.CudnnGRUSaveable - elif model._rnn_mode == CUDNN_RNN_TANH: - fn = cudnn_rnn_ops.CudnnRNNTanhSaveable - elif model._rnn_mode == CUDNN_RNN_RELU: - fn = cudnn_rnn_ops.CudnnRNNReluSaveable - params_saveable = fn( - params, - model.num_layers, - model.num_units, - model.input_size, - model.input_mode, - model.direction, - scope=base_variable_scope, - name=name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) - return params_saveable - - -def _MinLSTMParamSize(num_layers, - num_units, - input_size, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION): - if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units - all_biases = 8 * num_layers * num_units - return first_layer_weights + higher_layer_weights + all_biases - elif direction == cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = (num_layers - 1) * ( - 4 * 2 * num_units * num_units + 4 * num_units**2) - all_biases = 8 * num_layers * num_units - return 2 * (first_layer_weights + higher_layer_weights + all_biases) - else: - raise ValueError("%s direction is not supported.") + res = [] + ordered_extra_configs = collections.OrderedDict(extra_configs) + keys = ordered_extra_configs.keys() + # A list of list of configs. + # The outer loop is iterating keys, the innner is values of one key. + combined_kv = [[(k, v) for v in ordered_extra_configs[k]] for k in keys] + logging.info("combined_kv: %s", combined_kv) + for inp in inputs: + # Each inp is a dict + for config in itertools.product(*combined_kv): + new_inp = dict(inp) + # config is a list in the form of [(k_i, v_j), (k_p, v_q), ...] 
+ suffix = ["%s_%s" % (p[0], str(p[1])) for p in config] + suffix = "_".join(suffix) + new_inp["testcase_name"] += "_" + suffix + for k, v in config: + new_inp[k] = v + # Remove not used keys from the new test case. + if remove_keys: + if not isinstance(remove_keys, (list, tuple)): + remove_keys = [remove_keys] + for k in remove_keys: + new_inp.pop(k, None) + logging.info("new_inp: %s", new_inp) + res.append(new_inp) + # Dedup, necessary if `remove_keys` is set. + return [dict(t) for t in {tuple(d.items()) for d in res}] -class CudnnRNNTestSaveRestore(TensorFlowTestCase): - def _CompareWeights(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - for lw, rw in zip(lhs, rhs): - self.assertAllEqual(lw, rw) +class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): - def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): - self.assertEqual(len(lhs), len(rhs)) - if rnn_mode == CUDNN_LSTM: - num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_GRU: - num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_RNN_TANH: - num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER - else: - num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER - num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 - num_params_per_layer *= num_dirs - self.assertEqual(num_params_per_layer * num_layers, len(lhs)) - - for i in range(num_layers): - layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - if direction == CUDNN_RNN_UNIDIRECTION: - self._CompareSingleLayerBiases(layer_lhs, layer_rhs) - else: - size = len(layer_lhs) - fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] - fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] - self._CompareSingleLayerBiases(fw_lhs, fw_rhs) - self._CompareSingleLayerBiases(bw_lhs, bw_rhs) - - def _CompareSingleLayerBiases(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - - 
lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] - lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] - self.assertEqual(len(lf_lhs), len(rt_lhs)) - self.assertEqual(len(lf_rhs), len(rt_rhs)) - - sum_lhs, sum_rhs = [], [] - for lf, rt in zip(lf_lhs, rt_lhs): - sum_lhs.append(lf + rt) - for lf, rt in zip(lf_rhs, rt_rhs): - sum_rhs.append(lf + rt) - self.assertEqual(len(sum_lhs), len(sum_rhs)) - for lf, rt in zip(sum_lhs, sum_rhs): - self.assertAllEqual(lf, rt) + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, + state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( + sess, num_units, input_size, batch_size, time, num_layers) - def _testSaveRestoreVariable(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - params = variables.Variable( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) - saveable = _CreateParamsSavable(params, model) - weights, biases = saveable._OpaqueParamsToCanonical() - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. 
- with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + for s, cu_s in zip(state_tuple, cu_state_tuple): + self.assertAllClose(s, cu_s, rtol=rtol, atol=atol) + for sg, cu_sg in zip(state_grad, cu_state_grad): + self.assertAllClose(sg, cu_sg, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) + self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) - weights_v, biases_v = sess.run([weights, biases]) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) - sess.run(reset_params) - saver.restore(sess, save_path) - weights_v_restored, biases_v_restored = sess.run([weights, biases]) - - self._CompareWeights(weights_v, weights_v_restored) - self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, - direction) - - def _testSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - names = ["rnn_1", "rnn_2"] - param_vars = [ - variables.Variable( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) for name in names - ] - saveables = [] - for name, params in 
zip(names, param_vars): - saveables.append(_CreateParamsSavable(params, model, name, name)) - weights1, biases1 = saveables[0]._OpaqueParamsToCanonical() - weights2, biases2 = saveables[1]._OpaqueParamsToCanonical() - reset_params = [ - state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) for params in param_vars - ] - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session(use_gpu=True, - graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - weights1_v, biases1_v = sess.run([weights1, biases1]) - weights2_v, biases2_v = sess.run([weights2, biases2]) - - sess.run(reset_params) - saver.restore(sess, save_path) - weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) - weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) - - self._CompareWeights(weights1_v, weights1_v_restored) - self._CompareWeights(weights2_v, weights2_v_restored) - self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, - direction) - self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, - direction) - - def _testSaveRestoreOutput(self, rnn_mode, direction, dtype): - with ops.Graph().as_default(): - num_layers = 2 - num_units = 7 - input_size = 7 - seq_length = 10 - batch_size = 5 - dir_count = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 - model = _CreateModel( - rnn_mode, + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs 
found") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, num_layers, + is_training=False) + + self.assertAllClose(outputs, cu_outputs) + # h + self.assertAllClose(state_tuple.h, cu_state_tuple.h) + # c + self.assertAllClose(state_tuple.c, cu_state_tuple.c) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, num_units, input_size, - direction=direction, - dtype=dtype) - params_size_t = model.params_size() - params = variables.Variable( - array_ops.ones([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - _CreateParamsSavable(params, model) - save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) - np.random.seed(1234) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - input_data = constant_op.constant( - np.random.randn(seq_length, batch_size, input_size), dtype=dtype) - input_h = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, 
num_units), - dtype=dtype) - if has_input_c: - input_c = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, num_units), - dtype=dtype) - outputs = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - outputs = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - total_sum = sum(map(math_ops.reduce_sum, outputs)) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run(total_sum) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - sess.run(reset_params) - saver.restore(sess, save_path) - total_sum_v_restored = sess.run(total_sum) - self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + # h + self.assertAllClose( + state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) + # c + self.assertAllClose( + state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSaveRestore(self): - rnn_modes = [ - cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU, - cudnn_rnn_ops.CUDNN_RNN_TANH, cudnn_rnn_ops.CUDNN_RNN_RELU - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - dtype_list = [dtypes.float32, dtypes.float64] - for rnn_mode, direction, dtype in itertools.product(rnn_modes, directions, - 
dtype_list): - self._testSaveRestoreVariable(rnn_mode, direction, dtype) - self._testSaveRestoreTwoVariables(rnn_mode, direction, dtype) - self._testSaveRestoreOutput(rnn_mode, direction, dtype) - - -class CudnnRNNTestParamsSize(TensorFlowTestCase): - - def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, - direction): - logging.info("Testing one lstm param size with config: %s", locals()) - min_params_size = _MinLSTMParamSize(num_layers, num_units, input_size, - direction) - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - num_layers, + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + # Hand-picked dropouts are used below (0. and 1.) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) + + self.assertAllClose(cu_outputs, cu_outputs2) + # h + self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) + # c + self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) + + +def RunGRU(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. + # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. 
when test training.") + + # set graph level random seed and numpy random seed. + random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + with variable_scope.variable_scope("test", initializer=initializer): + gate_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/kernel", + shape=[input_size + num_units, num_units * 2], + dtype=dtype) + gate_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/bias", + shape=[num_units * 2], + dtype=dtype) + candidate_inp_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/kernel", + shape=[input_size, num_units], + dtype=dtype) + candidate_inp_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/bias", + shape=[num_units], + dtype=dtype) + candidate_hid_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/kernel", + shape=[num_units, num_units], + dtype=dtype) + candidate_hid_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/bias", + shape=[num_units], + dtype=dtype) + + cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) + outputs_op, h_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=initial_h_op, + dtype=dtype, + time_major=True, + scope=None) + + ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel] + bs = [gate_bias, candidate_inp_bias, candidate_hid_bias] + # Convert to cudnn opaque param. 
+ format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + array_ops.zeros_like(cu_initial_h_op), # not used + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_GRU) + + if is_training: + (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, + cib_grad_op, chb_grad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op] + ws + bs) + + (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op) = cu_wgrad_op + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) = cu_bgrad_op + # cudnn gru has 2 biases for reset and update gates. When converting to tf + # canonical format, the two biases are summed into one. Thus here relevant + # bias gradient should be halved before comparing with tf gru. 
+ cu_gb_grad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, h, inp_grad, hgrad, wgrad, bgrad = sess.run([ + outputs_op, h_op, inp_grad_op, hgrad_op, + (gk_grad_op, cik_grad_op, chk_grad_op), + (gb_grad_op, cib_grad_op, chb_grad_op) + ]) + (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([ + cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) + ]) + # Remove the trivial 1st dimension + cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "hgrad: %s" % hgrad) + logging.vlog(1, "cu_hgrad: %s" % cu_hgrad) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) + else: + outputs, h = sess.run([outputs_op, h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + # Remove the trivial 1st dimension. 
+ cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + return outputs, cu_outputs, h, cu_h + + +class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): + + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, num_units, input_size, batch_size, time, num_layers) + + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + self.assertAllClose(hgrad, cu_hgrad, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + for bg, cu_bg in zip(bgrad, cu_bgrad): + self.assertAllClose(bg, cu_bg, rtol=rtol, atol=atol) + for wg, cu_wg in zip(wgrad, cu_wgrad): + self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper( num_units, input_size, - direction=direction) - params_size = model.params_size() - with self.test_session(use_gpu=True, 
graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size) - self.assertLessEqual(min_params_size, params_size_v) + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testLSTMParamsSize(self): - test_configs = [ - [4, 200, 200], - [4, 200, 300], - [4, 200, 100], - [1, 100, 200], - [2, 200, 100], - [3, 200, 400], - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - for (config, direction) in itertools.product(test_configs, directions): - num_layers, num_units, input_size = config - with ops.Graph().as_default(): - self._testOneLSTMParamsSize(num_layers, num_units, input_size, - direction) + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False) + self.assertAllClose(outputs, cu_outputs) + self.assertAllClose(h, cu_h) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testLSTMParamsSizeShape(self): - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - constant_op.constant([4]), 200, 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, constant_op.constant([200]), 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() - with self.assertRaisesRegexp( - 
ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) + + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + # Hand-picked dropouts are used below (0. and 1.) + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_h2) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) 
+ + self.assertAllClose(cu_outputs, cu_outputs2) + self.assertAllClose(cu_h[0], cu_h2[0]) + + +class CudnnParamsFormatConverterTest(TensorFlowTestCase, + parameterized.TestCase): + """Class for testing various format converters.""" + + def _test_lstm_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size, direction=direction) + + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + w = constant_op.constant( + np.random.rand(input_size + num_units, 4 * num_units), + dtype=dtypes.float32) + b = constant_op.constant( + np.random.rand(4 * num_units), dtype=dtypes.float32) + ws.append(w) + bs.append(b) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( cudnn_rnn_ops.CUDNN_LSTM, - 4, 200, constant_op.constant([200]), - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + num_layers, + num_units, + input_size, + direction=direction) + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) -class CudnnRNNTestInference(TensorFlowTestCase): + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) - def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, - expected, tolerance): - random_seed.set_random_seed(5678) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - input_mode="auto_select", - direction=(cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION), - dropout=dropout) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - params_size_t = model.params_size() - input_data = array_ops.ones([seq_length, batch_size, input_size]) - input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - params = variables.Variable( - array_ops.ones([params_size_t]), validate_shape=False) - if has_input_c: - input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - output, output_h = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run([total_sum]) + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = ( + np.sum([x.size for x in ws]) + np.sum([x.size for x in bs])) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, 
opaque_params_size_v) - self.assertAllClose( - total_sum_v[0], expected, atol=tolerance, rtol=tolerance) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_lstm(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleInference(self): - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "expected": 231833.22, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "expected": 130688, - "tolerance": 1e-2, - "shape": { - "num_layers": 2, - "num_units": 8, - "input_size": 4, - "batch_size": 4, - "seq_length": 2, - "dir_count": 1, - }, - }, - ] - # Cudnn scales result for dropout during training, therefore dropout has no - # impact for inference results. - # (lstm, gru, rnn_tanh are saturated in the test. 
rnn_relu case is most - # demonstrative of the dropout-invariant nature of CudnnRnn.) - dropouts = [0., 0.5, 1.] - for (config, dropout) in itertools.product(test_configs, dropouts): - rnn_mode = config["rnn_mode"] - expected = config["expected"] - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleInference( - rnn_mode, shape["num_layers"], shape["num_units"], - shape["input_size"], shape["batch_size"], shape["seq_length"], - shape["dir_count"], dropout, expected, tolerance) - - -class CudnnRNNTestTraining(TensorFlowTestCase): - - def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, dtype, - delta, tolerance): - # Gradient checking runs two forward ops with almost the same input. Need to - # make sure the drop patterns across the two runs are the same. - logging.info("Training test with config: %s", locals()) - old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - random_seed.set_random_seed(5678) - direction = (cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - params_size_t = model.params_size() - input_data = variables.Variable( - random_ops.random_uniform( - [seq_length, batch_size, input_size], dtype=dtype), - dtype=dtype) - input_h = variables.Variable( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) - params = variables.Variable( - random_ops.random_uniform([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - if has_input_c: - input_c = variables.Variable( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], 
dtype=dtype), - dtype=dtype) - - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params) - else: - output, output_h = model( - input_data=input_data, input_h=input_h, params=params) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size_t) - inputs_and_shapes = [ - (input_data, [seq_length, batch_size, input_size]), - (input_h, [num_layers * dir_count, batch_size, num_units]), - (params, [params_size_v]), - ] - if has_input_c: - inputs_and_shapes.append( - (input_c, [num_layers * dir_count, batch_size, num_units]),) - sess.run(variables.global_variables_initializer()) - all_inputs = [entry[0] for entry in inputs_and_shapes] - all_shapes = [entry[1] for entry in inputs_and_shapes] - - err = gradient_checker.compute_gradient_error( - all_inputs, all_shapes, total_sum, [1], delta=delta) - - self.assertLess(err, tolerance) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + def test_lstm_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + def _test_gru_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size, direction=direction) + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + gate_kernel = constant_op.constant( + np.random.rand(input_size + num_units, num_units * 
2), + dtype=dtypes.float32) + gate_bias = constant_op.constant( + np.random.rand(num_units * 2), dtype=dtypes.float32) + candidate_inp_kernel = constant_op.constant( + np.random.rand(input_size, num_units), dtype=dtypes.float32) + candidate_inp_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + candidate_hid_kernel = constant_op.constant( + np.random.rand(num_units, num_units), dtype=dtypes.float32) + candidate_hid_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + ws.extend([gate_kernel, candidate_inp_kernel, candidate_hid_kernel]) + bs.extend([gate_bias, candidate_inp_bias, candidate_hid_bias]) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + cudnn_rnn_ops.CUDNN_GRU, + num_layers, + num_units, + input_size, + direction=direction) + + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) + + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = ( + np.sum([x.size for x in ws]) + np.sum([x.size for x in bs])) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_gru(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTraining(self): - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - 
"dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float32, - "tolerance": 1.5e-2, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float32, - "tolerance": 4e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float32, - "tolerance": 5e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float32, - "tolerance": 5e-1, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - ] - dropouts = [0., 0.5, 1.] 
- dir_counts = [1] - for config, dropout, dir_count in itertools.product(test_configs, dropouts, - dir_counts): - rnn_mode = config["rnn_mode"] - dtype = config.get("dtype", dtypes.float32) - delta = config.get("delta", 1e-3) - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleTraining(rnn_mode, shape["num_layers"], - shape["num_units"], shape["input_size"], - shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, tolerance) + def test_gru_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + +class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): + """Class for testing various Cudnn Rnn SaveableObjects.""" + + def _create_opaque_param(self, + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name=None): + param_size_t = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + rnn_mode, num_layers, num_units, input_size, direction=direction) + init_val = random_ops.random_uniform([param_size_t]) + return variable_scope.get_variable( + name or "opaque_param", initializer=init_val, validate_shape=False) + + def _create_saveable(self, opaque_param, rnn_mode, num_units, input_size, + num_layers, direction): + if rnn_mode == CUDNN_LSTM: + fn = cudnn_rnn_ops.CudnnLSTMSaveable + elif rnn_mode == CUDNN_GRU: + fn = cudnn_rnn_ops.CudnnGRUSaveable + elif rnn_mode == CUDNN_RNN_TANH: + fn = cudnn_rnn_ops.CudnnRNNTanhSaveable + elif rnn_mode == CUDNN_RNN_RELU: + fn = cudnn_rnn_ops.CudnnRNNReluSaveable + saveable = fn( + opaque_param, num_layers, num_units, input_size, direction=direction) + return saveable + + def _compare_weights(self, lhs, rhs): + self.assertLen(rhs, len(lhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _compare_biases(self, lhs, rhs): + self.assertLen(rhs, 
len(lhs)) + for lf, rt in zip(lhs, rhs): + self.assertAllEqual(lf, rt) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_variable(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size, + num_layers, direction) + saveable = self._create_saveable(opaque_param, rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + weights_op, biases_op = saveable.format_converter.opaque_to_tf_canonical( + saveable._variables) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_op = state_ops.assign(opaque_param, + array_ops.zeros_like(opaque_param)) + sess.run(init_op) + self.assertEqual(save_path, saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + weights, biases = sess.run([weights_op, biases_op]) + + # Reset the opaque param value + sess.run(reset_op) + # Assert reset happened. + weights_z, biases_z = sess.run([weights_op, biases_op]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. 
+ saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_op, biases_op]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_params = [] + saveables = [] + num_opaque_params = 2 + for i in range(num_opaque_params): + opaque_params.append( + self._create_opaque_param( + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name="opaque_param_%d" % i)) + saveable = self._create_saveable(opaque_params[i], rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saveables.append(saveable) + + weights_ops, biases_ops = [], [] + for i in range(num_opaque_params): + weights_op, biases_op = ( + saveables[i].format_converter.opaque_to_tf_canonical( + saveables[i]._variables)) + weights_ops.append(weights_op) + biases_ops.append(biases_op) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_ops = [] + for i in range(num_opaque_params): + reset_ops.append( + state_ops.assign(opaque_params[i], + array_ops.zeros_like(opaque_params[i]))) + sess.run(init_op) + self.assertEqual(save_path, 
saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + for i in range(num_opaque_params): + weights, biases = sess.run([weights_ops[i], biases_ops[i]]) + + # Reset the opaque param value + sess.run(reset_ops[i]) + + # Assert reset happened. + weights_z, biases_z = sess.run([weights_ops[i], biases_ops[i]]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. + saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_ops[i], biases_ops[i]]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) if __name__ == "__main__": diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 1954f6717bbebd803b0ec45992b43cf68f5d72a0..6cc93dccb004687a2d583a5d1925ea6b98c98979 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -536,7 +536,9 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): save_path = os.path.join(self.get_temp_dir(), "save-restore-variable-test") saver = saver_lib.Saver() - weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + weights, biases = ( + model.rnn.saveable.format_converter._opaque_to_cu_canonical( + model.rnn.saveable._variables)) opaque_params = rnn.trainable_variables[0] # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save # Cudnn vars in canonical format. 
@@ -583,8 +585,12 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): dtype=dtype) opaque_params = (model1.rnn.trainable_variables[0], model2.rnn.trainable_variables[0]) - weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() - weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + saveable1 = model1.rnn.saveable + weights1, biases1 = saveable1.format_converter._opaque_to_cu_canonical( + saveable1._variables) + saveable2 = model1.rnn.saveable + weights2, biases2 = saveable2.format_converter._opaque_to_cu_canonical( + saveable2._variables) reset_params = [ state_ops.assign(params, array_ops.zeros_like(params, dtype=dtype)) diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py index f09466b631f69d6234573dd5eafada650421c117..60229af374be869005139921483793156e5e7a05 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -27,5 +27,10 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibl from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterGRU +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterLSTM +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterRelu +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterTanh from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + diff --git 
a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index a324c6e7d76223aaa6514e695e4ff8444db455d0..8e25637ed91a1559b321ea96efbfaa2910f67158 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -21,6 +21,7 @@ from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -322,7 +323,7 @@ class _CudnnRNN(base_layer.Layer): raise ValueError("The last dimension of the inputs to `CudnnRNN` " "should be defined. Found `None`.") self._input_size = input_shape[-1].value - self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + self.input_spec = input_spec.InputSpec(ndim=3, axes={-1: self._input_size}) self._set_scope(None) @@ -388,11 +389,11 @@ class _CudnnRNN(base_layer.Layer): output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. Raises: - ValueError: initial_state is not a tuple. + TypeError: initial_state is not a tuple. """ if initial_state is not None and not isinstance(initial_state, tuple): - raise ValueError("Invalid initial_state type: %s, expecting tuple.", - type(initial_state)) + raise TypeError("Invalid initial_state type: %s, expecting tuple." 
% + initial_state) dtype = self.dtype inputs = ops.convert_to_tensor(inputs, dtype=dtype) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 2c92f31788378c2a9f01183bc04b035668b59b59..1ce29b42d52ff67477161278ed11016c2e73041d 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -74,7 +74,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): - """Cudnn Compatible GRUCell. + r"""Cudnn Compatible GRUCell. A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by @@ -177,172 +177,60 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): return new_h, new_h -# TODO(yaozhang): make sure we only save the canonical version of params and -# don't save the platform-specific version to avoid potential race -# conditions where params is updated by both versions when being restored. -# Currently, checkpointing will function properly, despite that we save both -# versions, because Saver restores customized savables after Variables. -# However, it is good to not rely on this restoring order of Saver and to -# avoid unnecessary storage. Add a test to check only the canonical version is -# saved. -class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): - """Abstract SaveableObject implementation handling Cudnn opaque params.""" +class CudnnParamsFormatConverter(object): + """Abstract class that converts between params of Cudnn Rnn and TF Rnn.""" def __init__(self, - opaque_params, num_layers, num_units, input_size, input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - scope=None, - name="cudnn_rnn_saveable"): - """Creates a CudnnOpaqueParamsSaveable object. 
- - CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file - and is used to save/restore the weights and biases parameters in a - canonical format which is directly consumable by platform-independent tf - RNN cells. Parameters are saved as tensors layer by layer with weight - tensors followed by bias tensors, and forward direction followed by - backward direction (if applicable). When restoring, a user could name - param_variables as desired, and restore weight and bias tensors to these - variables. - - For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per - bias for each layer: tensor 0 is applied to the input from the previous - layer and tensor 1 to the recurrent input. - - For CudnnLSTM, there are 8 tensors per weight and per bias for each - layer: tensor 0-3 are applied to the input from the previous layer and - tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; - tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; - tensor 3 and 7 the output gate. - - For CudnnGRU, there are 6 tensors per weight and per bias for each layer: - tensor 0-2 are applied to the input from the previous layer and - tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; - tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + direction=CUDNN_RNN_UNIDIRECTION): + """Constructor. Args: - opaque_params: a variable, Cudnn RNN opaque params. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). 
- 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It could be one + of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input' + (default) always applies a linear projection of input onto RNN hidden + state. (standard RNN behavior). * 'skip_input' is only allowed when + input_size == num_units; * 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' - scope: string of VariableScope, the scope of equivalent subgraph - consisting only platform-independent tf RNN cells. - name: the name of the CudnnOpaqueParamsSaveable object. + 'unidirectional' or 'bidirectional' """ - # Define in subclasses. self._num_layers = num_layers self._input_size = input_size self._num_units = num_units self._input_mode = input_mode self._direction = direction - if scope is not None: - scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope - self._scope = scope_name or None - else: - self._scope = None - - self._variables = opaque_params self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 self._num_params = ( self._num_params_per_layer * self._num_layers * self._num_dirs) - weights, biases = self._OpaqueParamsToCanonical() - (weights, weight_names), (biases, bias_names) = self._TransformCanonical( - weights, biases) - # We currently don't use slice_spec. It might be useful in a distributed - # setting where each parameter server node stores a slice of variable, - # instead of having the master pull all slices and then save them. 
- slice_spec = "" - params = weights + biases - self._weight_names = weight_names - self._bias_names = bias_names - self._param_names = weight_names + bias_names - prefixed_param_names = weight_names + bias_names - if self._scope: - prefixed_param_names = [ - "%s/%s" % (self._scope, pn) for pn in prefixed_param_names] - specs = [ - saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) - for param, param_name in zip(params, prefixed_param_names) - ] - super(CudnnOpaqueParamsSaveable, self).__init__( - array_ops.identity(self._variables), specs, name) - - def restore(self, restored_tensors, restored_shapes): - weights, biases = self._ReverseTransformCanonical(restored_tensors) - weights = [array_ops.reshape(w, [-1]) for w in weights] - opaque_params = self._CanonicalToOpaqueParams(weights, biases) - - return state_ops.assign( - self._variables, opaque_params, validate_shape=False) + def tf_canonical_to_opaque(self, tf_canonicals): + r"""Converts tf canonical weights to cudnn opaque param.""" + cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals) + cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights] + opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases) + return opaque_params - def _checkpointable_save(self, save_buffer): - weights, biases = self._OpaqueParamsToCanonical() - with ops.device("gpu:0"): - (weights, _), (biases, _) = self._TransformCanonical( - weights, biases) - for name, tensor in zip(self._param_names, weights + biases): - save_buffer[name] = array_ops.identity(tensor) + def opaque_to_tf_canonical(self, opaque_param): + r"""Converts cudnn opaque param to tf canonical weights.""" + cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param) + weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases) + return weights, biases - def _checkpointable_restore(self, restore_buffer): - tensors = [array_ops.identity(restore_buffer[name]) - for name in self._param_names] - return 
self.restore( - restored_tensors=tensors, - restored_shapes=None # Unused - ) - - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. - - When saving or restoring, converts to or from the opaque buffer - format. Weights are saved and loaded in the configuration expected by - cuDNN-compatible cells. - - Args: - checkpointable: An object inheriting from `CheckpointableBase` to add - dependencies too (typically the cuDNN `Layer`). - dtype: The dtype for the canonical parameter Tensors. - """ - split_dependencies = split_dependency.split_dependency( - component_names=self._param_names, - component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) - - def _checkpointable_track_params(self, checkpointable, params): - """Tracks parameters in a canonical configuration.""" - return # NotImplementedError raised by the Layer. - - def _TFCanonicalNamePrefix(self, layer, is_fwd=True): - if self._direction == CUDNN_RNN_UNIDIRECTION: - return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) - else: - if is_fwd: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % - (layer, self._rnn_cell_name)) - else: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % - (layer, self._rnn_cell_name)) - - def _OpaqueParamsToCanonical(self): + def _opaque_to_cu_canonical(self, opaque_param): """Converts opaque params to Cudnn canonical format. + Args: + opaque_param: An opaque tensor storing cudnn rnn params (weights and + biases). Returns: 2 list for weights and biases respectively. 
""" @@ -351,14 +239,14 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, - params=self._variables, + params=opaque_param, num_params=self._num_params, rnn_mode=self._rnn_mode, input_mode=self._input_mode, direction=self._direction) return (weights, biases) - def _CanonicalToOpaqueParams(self, cu_weights, cu_biases): + def _cu_canonical_to_opaque(self, cu_weights, cu_biases): """Converts from Cudnn canonical format to opaque params. Args: @@ -378,7 +266,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): input_mode=self._input_mode, direction=self._direction) - def _TransformCanonical(self, cu_weights, cu_biases): + def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. The elements of argument lists are laid out in the following format: @@ -398,46 +286,43 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): cu_weights: a list of tensors of Cudnn canonical weights. cu_biases: a list of tensors of Cudnn canonical biases. Returns: - 2 tuples, one for weights and the other for bias. - Each tuple has two lists: the 1st for transformed tf canonical tensors - and the 2nd for the names of the tensors under which they are saved. + 1 tuple, tf canonical weights and biases. 
""" tf_weights, tf_biases = [], [] - tf_weights_names, tf_bias_names = [], [] layer_weights_num = self._num_params_per_layer * self._num_dirs layer_biases_num = layer_weights_num for i in range(self._num_layers): - layer_weights = cu_weights[i * layer_weights_num: - (i + 1) * layer_weights_num] + layer_weights = cu_weights[i * layer_weights_num:(i + 1) * + layer_weights_num] layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: - prefix = self._TFCanonicalNamePrefix(i) - self._TransformSingleLayerCanonical(layer_weights, layer_biases, prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) + self._cu_canonical_to_tf_canonical_single_layer( + layer_weights, layer_biases, tf_weights, tf_biases) else: - fw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=True) - bw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=False) - fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] fw_biases = layer_biases[:len(layer_biases) // 2] bw_biases = layer_biases[len(layer_biases) // 2:] - self._TransformSingleLayerCanonical(fw_weights, fw_biases, fw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - - self._TransformSingleLayerCanonical(bw_weights, bw_biases, bw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - return (tf_weights, tf_weights_names), (tf_biases, tf_bias_names) - - def _TransformSingleLayerCanonical(self, cu_weights, cu_biases, prefix, - tf_weights, tf_weights_names, tf_biases, - tf_bias_names): + self._cu_canonical_to_tf_canonical_single_layer( + fw_weights, + fw_biases, + tf_weights, + tf_biases, + ) + + self._cu_canonical_to_tf_canonical_single_layer( + bw_weights, + bw_biases, + tf_weights, + tf_biases, + ) + return (tf_weights, tf_biases) + + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): r"""Transform single layer Cudnn canonicals to tf 
canonicals. The elements of cu_weights, cu_biases are laid out in the following format: @@ -447,15 +332,12 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Args: cu_weights: a list of tensors, single layer weights. cu_biases: a list of tensors, single layer biases. - prefix: the shared prefix of all tensor names. tf_weights: a list where transformed weights are stored. - tf_weights_names: a list where names of transformed weights are stored. tf_biases: a list where transformed biases are stored. - tf_bias_names: a list where names of transformed biases are stored. """ raise NotImplementedError("Abstract method") - def _ReverseTransformCanonical(self, tf_canonicals): + def _tf_canonical_to_cu_canonical(self, tf_canonicals): r"""Transform from tf canonical to Cudnn canonical. This is the reverse routine of _TransformCanonical(). @@ -502,30 +384,27 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return cu_weights, cu_biases def _cudnn_to_tf_weights(self, *cu_weights): - r"""Stitching cudnn canonical weights to generate tf canonical weights.""" + r"""Stitches cudnn canonical weights to generate tf canonical weights.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_weights(self, layer, *tf_weights): - r"""Reverse the operations in StitchWeights().""" + r"""Reverses the operations in StitchWeights().""" raise NotImplementedError("Abstract method") def _cudnn_to_tf_biases(self, *biases): - r"""Stitching cudnn canonical biases to generate tf canonical biases.""" + r"""Stitches cudnn canonical biases to generate tf canonical biases.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_biases(self, *tf_biases): - r"""Reverse the operations in StitchBiases().""" + r"""Reverses the operations in StitchBiases().""" raise NotImplementedError("Abstract method") -class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn LSTM opaque params.""" - +class 
CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF LSTM.""" _rnn_mode = CUDNN_LSTM _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) - def _cudnn_to_tf_gate_params(self, *cu_gate_order): i_g, f_g, c_g, o_g = cu_gate_order return [i_g, c_g, f_g, o_g] @@ -603,44 +482,16 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # Return ifco order for Cudnn LSTM. return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): - (w,) = self._cudnn_to_tf_weights(*weights) - (b,) = self._cudnn_to_tf_biases(*biases) - + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): + (w,) = self._cudnn_to_tf_weights(*cu_weights) + (b,) = self._cudnn_to_tf_biases(*cu_biases) tf_weights.append(w) - tf_weights_names.append(prefix + "/kernel") - tf_biases.append(b) - tf_bias_names.append(prefix + "/bias") - - def _checkpointable_track_params(self, checkpointable, params): - """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" - biases = [] - weights = [] - for name in self._weight_names: - weights.append(params[name]) - for name in self._bias_names: - biases.append(params[name]) - assert len(params) == len(weights) + len(biases) - if len(weights) == 1 and len(biases) == 1: - # For single-layer cells, allow substituting a cell with no MultiRNNCell - # wrapping. 
- kernel, = weights # pylint: disable=unbalanced-tuple-unpacking - bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access - assert len(biases) == len(weights) - for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access - cell.bias = bias - cell.kernel = kernel -class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn GRU opaque params.""" +class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF GRU.""" _rnn_mode = CUDNN_GRU _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER @@ -702,29 +553,18 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): b_ri, b_rr = array_ops.split(br, 2, axis=0) return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): # pylint: disable=invalid-name - W_ir, w_h, r_h = self._cudnn_to_tf_weights(*weights) - b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*biases) + W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights) + b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases) # pylint: enable=invalid-name - tf_weights.extend([W_ir, w_h, r_h]) - tf_weights_names.append(prefix + "/gates/kernel") - tf_weights_names.append(prefix + "/candidate/input_projection/kernel") - tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") - tf_biases.extend([b_ir, b_wh, b_rh]) - tf_bias_names.append(prefix + "/gates/bias") - tf_bias_names.append(prefix + 
"/candidate/input_projection/bias") - tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") - -class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" - _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) +class CudnnParamsFormatConverterBasic(CudnnParamsFormatConverterLSTM): + """Helper class that converts between params of Cudnn and TF Relu/Tanh RNN.""" def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -766,18 +606,270 @@ class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): return b_i, b_h -class CudnnRNNTanhSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" +class CudnnParamsFormatConverterTanh(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Tanh RNN.""" _rnn_mode = CUDNN_RNN_TANH _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER -class CudnnRNNReluSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Relu opaque params.""" +class CudnnParamsFormatConverterRelu(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Relu RNN.""" _rnn_mode = CUDNN_RNN_RELU _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER +# TODO(yaozhang): make sure we only save the canonical version of params and +# don't save the platform-specific version to avoid potential race +# conditions where params is updated by both versions when being restored. +# Currently, checkpointing will function properly, despite that we save both +# versions, because Saver restores customized savables after Variables. +# However, it is good to not rely on this restoring order of Saver and to +# avoid unnecessary storage. Add a test to check only the canonical version is +# saved. 
+class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): + """Abstract SaveableObject implementation handling Cudnn opaque params.""" + + def __init__(self, + opaque_params, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + scope=None, + name="cudnn_rnn_saveable"): + """Creates a CudnnOpaqueParamsSaveable object. + + CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file + and is used to save/restore the weights and biases parameters in a + canonical format which is directly consumable by platform-independent tf + RNN cells. Parameters are saved as tensors layer by layer with weight + tensors followed by bias tensors, and forward direction followed by + backward direction (if applicable). When restoring, a user could name + param_variables as desired, and restore weight and bias tensors to these + variables. + + For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per + bias for each layer: tensor 0 is applied to the input from the previous + layer and tensor 1 to the recurrent input. + + For CudnnLSTM, there are 8 tensors per weight and per bias for each + layer: tensor 0-3 are applied to the input from the previous layer and + tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; + tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; + tensor 3 and 7 the output gate. + + For CudnnGRU, there are 6 tensors per weight and per bias for each layer: + tensor 0-2 are applied to the input from the previous layer and + tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; + tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + + Args: + opaque_params: a variable, Cudnn RNN opaque params. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. 
+ input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + scope: string of VariableScope, the scope of equivalent subgraph + consisting only platform-independent tf RNN cells. + name: the name of the CudnnOpaqueParamsSaveable object. + """ + # Define in subclasses. + self._num_layers = num_layers + self._input_size = input_size + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + if scope is not None: + scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope + self._scope = scope_name or None + else: + self._scope = None + + self._variables = opaque_params + self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + # Defined in subclasses. + self._format_converter = None + + tf_weights, tf_biases = ( + self.format_converter.opaque_to_tf_canonical(self._variables)) + tf_weight_names, tf_bias_names = self._tf_canonical_names() + # We currently don't use slice_spec. It might be useful in a distributed + # setting where each parameter server node stores a slice of variable, + # instead of having the master pull all slices and then save them. 
+ slice_spec = "" + params = tf_weights + tf_biases + self._weight_names = tf_weight_names + self._bias_names = tf_bias_names + self._param_names = tf_weight_names + tf_bias_names + prefixed_param_names = tf_weight_names + tf_bias_names + if self._scope: + prefixed_param_names = [ + "%s/%s" % (self._scope, pn) for pn in prefixed_param_names + ] + specs = [ + saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) + for param, param_name in zip(params, prefixed_param_names) + ] + super(CudnnOpaqueParamsSaveable, self).__init__( + array_ops.identity(self._variables), specs, name) + + @property + def format_converter(self): + if self._format_converter is None: + self._format_converter = self._format_converter_cls( + self._num_layers, self._num_units, self._input_size, self._input_mode, + self._direction) + return self._format_converter + + def restore(self, restored_tensors, restored_shapes): + opaque_params = self.format_converter.tf_canonical_to_opaque( + restored_tensors) + return state_ops.assign( + self._variables, opaque_params, validate_shape=False) + + def _checkpointable_save(self, save_buffer): + weights, biases = self.format_converter.opaque_to_tf_canonical( + self._variables) + for name, tensor in zip(self._param_names, weights + biases): + save_buffer[name] = array_ops.identity(tensor) + + def _checkpointable_restore(self, restore_buffer): + tensors = [ + array_ops.identity(restore_buffer[name]) for name in self._param_names + ] + return self.restore( + restored_tensors=tensors, + restored_shapes=None # Unused + ) + + def _add_checkpointable_dependencies(self, checkpointable, dtype): + """Add canonical weight dependencies to `checkpointable`. + + When saving or restoring, converts to or from the opaque buffer + format. Weights are saved and loaded in the configuration expected by + cuDNN-compatible cells. + + Args: + checkpointable: An object inheriting from `CheckpointableBase` to add + dependencies too (typically the cuDNN `Layer`). 
+ dtype: The dtype for the canonical parameter Tensors. + """ + split_dependencies = split_dependency.split_dependency( + component_names=self._param_names, + component_dtypes=(dtype,) * len(self._param_names), + fill_save_buffer_fn=self._checkpointable_save, + consume_restore_buffer_fn=self._checkpointable_restore) + self._checkpointable_track_params(checkpointable, split_dependencies) + + def _checkpointable_track_params(self, checkpointable, params): + """Tracks parameters in a canonical configuration.""" + return # NotImplementedError raised by the Layer. + + def _tf_canonical_names(self): + tf_weights_names, tf_biases_names = [], [] + for i in range(self._num_layers): + if self._direction == CUDNN_RNN_UNIDIRECTION: + prefix = self._tf_canonical_name_prefix(i) + self._tf_canonical_names_single_layer(prefix, tf_weights_names, + tf_biases_names) + else: + fwd_prefix = self._tf_canonical_name_prefix(i, is_fwd=True) + bak_prefix = self._tf_canonical_name_prefix(i, is_fwd=False) + + self._tf_canonical_names_single_layer(fwd_prefix, tf_weights_names, + tf_biases_names) + self._tf_canonical_names_single_layer(bak_prefix, tf_weights_names, + tf_biases_names) + return tf_weights_names, tf_biases_names + + def _tf_canonical_name_prefix(self, layer, is_fwd=True): + if self._direction == CUDNN_RNN_UNIDIRECTION: + return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) + else: + if is_fwd: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % + (layer, self._rnn_cell_name)) + else: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % + (layer, self._rnn_cell_name)) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_biases_names): + raise NotImplementedError("Abstract method") + + +class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn LSTM opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterLSTM + _rnn_cell_name = 
base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/kernel") + tf_bias_names.append(prefix + "/bias") + + def _checkpointable_track_params(self, checkpointable, params): + """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" + biases = [] + weights = [] + for name in self._weight_names: + weights.append(params[name]) + for name in self._bias_names: + biases.append(params[name]) + assert len(params) == len(weights) + len(biases) + if len(weights) == 1 and len(biases) == 1: + # For single-layer cells, allow substituting a cell with no MultiRNNCell + # wrapping. + kernel, = weights # pylint: disable=unbalanced-tuple-unpacking + bias, = biases # pylint: disable=unbalanced-tuple-unpacking + checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access + checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + assert len(biases) == len(weights) + for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): + cell = checkpointable_lib.Checkpointable() + checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell.bias = bias + cell.kernel = kernel + + +class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn GRU opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterGRU + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/gates/kernel") + tf_weights_names.append(prefix + "/candidate/input_projection/kernel") + tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") + + tf_bias_names.append(prefix + "/gates/bias") + tf_bias_names.append(prefix + 
"/candidate/input_projection/bias") + tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") + + +class CudnnRNNTanhSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterTanh + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + +class CudnnRNNReluSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterRelu + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + _cudnn_rnn_common_doc_string = """ Cudnn RNN has an opaque parameter buffer that can be used for inference and training. But it is possible that the layout of the parameter buffers @@ -850,7 +942,7 @@ def _get_num_params(rnn_mode, num_layers, direction): elif rnn_mode == CUDNN_RNN_TANH: num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: - raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) + raise ValueError("Invalid \'rnn_mode\': %s" % rnn_mode) num_params = num_layers * num_params_per_layer if direction != CUDNN_RNN_UNIDIRECTION: num_params *= 2 @@ -918,7 +1010,7 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 is not "1": + if use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1582,7 +1674,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): """ if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s", direction) + raise ValueError("Invalid direction: %s" % direction) super(_CudnnRNNNoInputC, self).__init__( self._rnn_mode, diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index a87a5624c88d1d0af10055261dad55937ed6aeb0..3ecd755d86f6be47910aebbdb46d335d165427d8 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -26,7 +26,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ 
"//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", - "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", "//tensorflow/contrib/distribute/python:one_device_strategy", @@ -35,6 +34,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:distribute_config", "//tensorflow/python/distribute:distribute_coordinator", ], diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index f82453f3b5ea01b8bb64a70bd49f5e3e831bb4e2..8a8dc159ade6f2a4a9b5ec29055ea4848492b29f 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -46,6 +46,9 @@ Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` Take a very simple model consisting of a single layer: ```python +import tensorflow as tf +from tensorflow import keras + inputs = tf.keras.layers.Input(shape=(1,)) predictions = tf.keras.layers.Dense(1)(inputs) model = tf.keras.models.Model(inputs=inputs, outputs=predictions) @@ -90,8 +93,8 @@ Similarly, we can also call `evaluate` and `predict` as before using appropriate datasets. 
```python -model.evaluate(eval_dataset) -model.predict(predict_dataset) +model.evaluate(eval_dataset, steps=1) +model.predict(predict_dataset, steps=1) ``` That's all you need to train your model with Keras on multiple GPUs with @@ -131,7 +134,7 @@ def model_fn(features, labels, mode): return tf.estimator.EstimatorSpec(mode, loss=loss) if mode == tf.estimator.ModeKeys.TRAIN: - train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) ``` @@ -245,19 +248,17 @@ Let's use the same example for multi-worker. We'll start a cluster with 3 workers doing synchronous all-reduce training. In the following code snippet, we start multi-worker training using `tf.estimator.train_and_evaluate`: - ```python def model_main(): - estimator = ... distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig(train_distribute=distribution) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) ``` - **Note**: You don't have to set "TF\_CONFIG" manually if you use our provided Kubernetes template. @@ -324,13 +325,13 @@ start training. On your laptop, you can run ```python -estimator = ... 
distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig( experimental_distribute=tf.contrib.distribute.DistributeConfig( train_distribute=distribution, remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]})) +estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index ab2f221dc6486666e914deb19dd56c7687606e2f..8ec73654e30e4967f318c558ba94301e84a206e4 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -25,13 +25,13 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy -from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server from tensorflow.python.training.distribute import * @@ -46,6 +46,7 @@ _allowed_symbols = [ 'CrossDeviceOps', 'DistributeConfig', 'DistributionStrategy', + 'DistributionStrategyExtended', 
'MirroredStrategy', 'Monitor', 'MultiWorkerAllReduce', @@ -62,6 +63,7 @@ _allowed_symbols = [ 'get_loss_reduction', 'get_replica_context', 'has_distribution_strategy', + 'in_cross_replica_context', 'require_replica_context', 'run_standard_tensorflow_server', 'UpdateContext', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 22736c799d276033c0ddc112d17e898be944c933..91282a8c1dab051da7894956d202c88c90e2fe39 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -16,45 +16,26 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") # TODO(priyag): Figure out testonly issues that are preventing us from # including our tests in pip for now. -py_library( - name = "values", - srcs = ["values.py"], - visibility = ["//tensorflow:internal"], - deps = [ - ":input_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:multi_device_iterator_ops", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/checkpointable:base", - "@six_archive//:six", - ], -) - cuda_py_test( name = "values_test", srcs = ["values_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:errors", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:device_util", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + 
"//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", - "//tensorflow/python:device_util", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], @@ -68,25 +49,9 @@ py_library( srcs = ["mirrored_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":shared_variable_creator", - ":values", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device", - "//tensorflow/python:device_util", "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:values", ], ) @@ -95,16 +60,17 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -116,7 +82,7 @@ cuda_py_test( ":combinations", ":multi_worker_test_base", ":parameter_server_strategy", - ":values", + ":strategy_test_lib", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", 
"//tensorflow/python:array_ops", @@ -127,10 +93,12 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:layers", "//tensorflow/python:session", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -145,12 +113,13 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":values", - "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", "//tensorflow/python:distribute", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -161,16 +130,16 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":cross_tower_utils", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -233,28 +202,6 @@ py_test( ], ) -py_test( - name = "mirrored_strategy_test", - srcs = ["mirrored_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":mirrored_strategy", - ":multi_worker_test_base", - ":strategy_test_lib", - "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - 
"//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_test( name = "one_device_strategy_test", srcs = ["one_device_strategy_test.py"], @@ -270,35 +217,32 @@ py_test( ], ) +# TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", ":strategy_test_lib", - "//tensorflow/python:distribute", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:state_ops", "//tensorflow/python:variable_scope", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 5, tags = [ "guitar", - "no_pip", "multi_and_single_gpu", - # Do not perform the extra analysis on this test, because it is already - # performed for the `:mirrored_strategy_test` target. 
- "no_oss", - "noasan", - "notap", - "notsan", + "no_pip", ], ) @@ -337,12 +281,15 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":one_device_strategy", - ":values", "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", ], ) @@ -352,7 +299,6 @@ cuda_py_test( additional_deps = [ ":collective_all_reduce_strategy", ":combinations", - ":cross_tower_utils", ":multi_worker_test_base", ":strategy_test_lib", "@absl_py//absl/testing:parameterized", @@ -368,15 +314,13 @@ cuda_py_test( "//tensorflow/python:layers", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], tags = [ "multi_and_single_gpu", "no_pip", - # TODO(b/118820960): Re-enable this test in guitar. 
- "manual", - "noguitar", ], ) @@ -470,6 +414,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", ], ) @@ -478,20 +423,11 @@ cuda_py_test( name = "keras_optimizer_v2_test", srcs = ["keras_optimizer_v2_test.py"], additional_deps = [ - ":combinations", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", - "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/feature_column", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + ":keras_test_lib", ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", ], ) @@ -509,7 +445,9 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute", + "//tensorflow/python/distribute:distribute_config", + "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column", @@ -517,7 +455,7 @@ cuda_py_test( "//tensorflow/python:platform", "//tensorflow/python:summary", ], - shard_count = 5, + shard_count = 48, tags = [ "multi_and_single_gpu", "no_pip", @@ -525,6 +463,7 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "no_oss", # http://b/119349471 ], ) @@ -600,52 +539,16 @@ cuda_py_test( ], ) -py_library( - name = "shared_variable_creator", - srcs = ["shared_variable_creator.py"], - visibility = ["//tensorflow:internal"], -) - -py_test( - name = "shared_variable_creator_test", - srcs = ["shared_variable_creator_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":shared_variable_creator", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variable_scope", - 
"//tensorflow/python/eager:test", - ], -) - -py_library( - name = "cross_tower_utils", - srcs = ["cross_tower_utils.py"], - srcs_version = "PY2AND3", - deps = [ - ":values", - "//tensorflow/contrib/all_reduce:all_reduce_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:device", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:nccl_ops", - ], -) - cuda_py_test( - name = "cross_tower_utils_test", - srcs = ["cross_tower_utils_test.py"], + name = "cross_device_utils_test", + srcs = ["cross_device_utils_test.py"], additional_deps = [ ":combinations", - ":cross_tower_utils", "@absl_py//absl/testing:parameterized", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -654,40 +557,20 @@ cuda_py_test( ], ) -py_library( - name = "cross_tower_ops", - srcs = ["cross_tower_ops.py"], - srcs_version = "PY2AND3", - deps = [ - ":cross_tower_utils", - ":values", - "//tensorflow/python:array_ops", - "//tensorflow/python:device_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "@six_archive//:six", - ], -) - cuda_py_test( - name = "cross_tower_ops_test", - srcs = ["cross_tower_ops_test.py"], + name = "cross_device_ops_test", + srcs = ["cross_device_ops_test.py"], additional_deps = [ ":combinations", - ":cross_tower_ops", ":multi_worker_test_base", ":mirrored_strategy", - ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", 
"//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -697,37 +580,6 @@ cuda_py_test( ], ) -py_library( - name = "input_ops", - srcs = ["input_ops.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/util:nest", - ], -) - -cuda_py_test( - name = "input_ops_test", - srcs = ["input_ops_test.py"], - additional_deps = [ - ":input_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:errors", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:io_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python:util", - ], - tags = [ - "no_pip", - ], -) - py_library( name = "keras_test_lib", testonly = 1, @@ -757,8 +609,6 @@ cuda_py_test( "no_oss", # TODO(b/117919883): Fix python error. "no_pip", "no_windows_gpu", - # TODO(b/118815591): Re-enable this test in guitar.) 
- "noguitar", "notsan", ], ) @@ -769,7 +619,6 @@ py_library( srcs = ["metrics_v1_test.py"], deps = [ ":combinations", - "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index b311644cb22898df082b0c803d1a8960fe159c98..31bd0e996a247a2fc01405fb3b8172a40853d698 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -43,7 +43,9 @@ class CheckpointUtilsWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], in_replica_mode=[True, False], mode=["graph"])) def testInitFromCheckpoint(self, distribution, in_replica_mode): @@ -69,7 +71,7 @@ class CheckpointUtilsWithDistributionStrategyTest( with ops.Graph().as_default() as g, distribution.scope(): if in_replica_mode: - distribution.call_for_each_replica(init_and_verify, g) + distribution.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index d9339f8f75acda3695d33c55409e921a9627bac7..906377b7395a520780e485461b83298320ebdcb3 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,21 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from 
tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib # TODO(yuefengz): support in-graph replication. -class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): +class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. It is similar to the MirroredStrategy but it uses collective ops for @@ -53,10 +54,20 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): num_gpus_per_worker: number of local GPUs or GPUs per worker, the default is 0 meaning CPU only. 
""" + super(CollectiveAllReduceStrategy, self).__init__( + CollectiveAllReduceExtended(self, num_gpus_per_worker)) + + +class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): + """Implementation of CollectiveAllReduceStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + distribute_lib.DistributionStrategyExtended.__init__( + self, container_strategy) self._num_gpus_per_worker = num_gpus_per_worker - self._initialize_local_worker(num_gpus_per_worker) + self._initialize_local_worker(container_strategy, num_gpus_per_worker) - def _initialize_local_worker(self, num_gpus_per_worker): + def _initialize_local_worker(self, container_strategy, num_gpus_per_worker): """Initializes the object for local training.""" self._is_chief = True self._num_workers = 1 @@ -68,10 +79,11 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): else: local_devices = ["/device:CPU:0"] - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( + self._collective_keys = cross_device_utils.CollectiveKeys() + super(CollectiveAllReduceExtended, self).__init__( + container_strategy, devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + cross_device_ops=cross_device_ops_lib.CollectiveAllReduce( num_workers=1, num_gpus_per_worker=num_gpus_per_worker, collective_keys=self._collective_keys)) @@ -83,8 +95,8 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): logging.info("CollectiveAllReduceStrategy with local_devices = %r", local_devices) - def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, - task_type, task_id): + def _initialize_multi_worker(self, container_strategy, num_gpus_per_worker, + cluster_spec, task_type, task_id): """Initializes the object for multi-worker training.""" if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " @@ -94,8 +106,7 @@ class 
CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." % task_type) cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len( - cluster_spec.as_dict().get("chief", [])) + self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) if not self._num_workers: raise ValueError("No `worker` or `chief` tasks can be found in " "`cluster_spec`.") @@ -112,10 +123,11 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): else: local_devices = [worker_device] - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( + self._collective_keys = cross_device_utils.CollectiveKeys() + super(CollectiveAllReduceExtended, self).__init__( + container_strategy, devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + cross_device_ops=cross_device_ops_lib.CollectiveAllReduce( num_workers=self._num_workers, num_gpus_per_worker=num_gpus_per_worker, collective_keys=self._collective_keys)) @@ -202,17 +214,35 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return mirrored_strategy._create_mirrored_variable( devices, _real_mirrored_creator, *args, **kwargs) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. 
- return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices, True) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec is None: + input_pipeline_id = 0 + else: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + input_context = distribute_lib.InputContext( + num_input_pipelines=self._num_workers, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + + return values.InputFunctionIterator( + input_fn, [(self._default_device, self._devices)], [input_context]) + + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the object. Args: @@ -229,8 +259,9 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): # If a `cluster_spec` is already passed in, do nothing here. # TODO(yuefengz): check `cluster_spec` is the same if this object has # already been initialized with a `cluster_spec`. 
- self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, - task_type, task_id) + self._initialize_multi_worker( + self._container_strategy(), self._num_gpus_per_worker, cluster_spec, + task_type, task_id) if not session_config: return @@ -271,11 +302,11 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): "/job:%s/task:%d" % (self._task_type, self._task_id)) @property - def between_graph(self): + def experimental_between_graph(self): return True @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -287,6 +318,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return self._is_chief @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return len(self._devices) * self._num_workers + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 19b59513d81c5b0cebf5e44aa66b110db86a91c8..eb2b859aa559dd0c72351a009149ffdcb3c96b7c 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -23,13 +23,18 @@ import numpy as np from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import cross_device_utils +from 
tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops @@ -54,8 +59,6 @@ class CollectiveAllReduceStrategyTestBase( self._run_options = config_pb2.RunOptions() self._run_options.experimental.collective_graph_key = 6 - self._sess_config = config_pb2.ConfigProto() - # We use a different key_base for each test so that collective keys won't be # reused. # TODO(yuefengz, tucker): enable it to reuse collective keys in different @@ -66,33 +69,37 @@ class CollectiveAllReduceStrategyTestBase( def _get_test_object(self, task_type, task_id, num_gpus=0): distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( num_gpus_per_worker=num_gpus) + session_config = config_pb2.ConfigProto() if task_type and task_id is not None: distribution.configure( - session_config=self._sess_config, + session_config=session_config, cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_start=num_gpus * 100 + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution._collective_keys = collective_keys - distribution._cross_tower_ops._collective_keys = collective_keys + distribution.extended._collective_keys = collective_keys + distribution.extended._cross_device_ops._collective_keys = collective_keys if task_type and task_id is not None: - return distribution, 'grpc://' + 
self._cluster_spec[task_type][task_id] + return distribution, 'grpc://' + self._cluster_spec[task_type][ + task_id], session_config else: - return distribution, '' + return distribution, '', session_config def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess, \ d.scope(): - l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + l = core.Dense(1, use_bias=False, + name='gpu_%d' % d.extended._num_gpus_per_worker) def loss_fn(x): y = array_ops.reshape(l(x), []) - constant_op.constant(1.) @@ -117,7 +124,7 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -127,7 +134,7 @@ class CollectiveAllReduceStrategyTestBase( with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -135,7 +142,7 @@ class CollectiveAllReduceStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True sess.run( @@ -154,7 +161,8 @@ class CollectiveAllReduceStrategyTestBase( return error_after < error_before def _test_complex_model(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) def model_fn(): """Mnist model with synthetic input.""" @@ -193,7 +201,7 @@ class CollectiveAllReduceStrategyTestBase( return train_op with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess: with d.scope(): train_op = d.call_for_each_replica(model_fn) @@ -204,10 +212,10 @@ class CollectiveAllReduceStrategyTestBase( return True def _test_variable_initialization(self, task_type, task_id, num_gpus): - distribution, master_target = self._get_test_object(task_type, task_id, - num_gpus) + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess, \ distribution.scope(): @@ -222,7 +230,7 @@ class CollectiveAllReduceStrategyTestBase( x = distribution.call_for_each_replica(model_fn) reduced_x = distribution.unwrap( distribution.reduce( - variable_scope.VariableAggregation.MEAN, x, + reduce_util.ReduceOp.MEAN, x, destinations='/cpu:0'))[0] x = distribution.unwrap(x)[0] @@ -237,9 +245,42 @@ class CollectiveAllReduceStrategyTestBase( reduced_x_value))) 
return np.allclose(x_value, reduced_x_value, atol=1e-5) + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class DistributedCollectiveAllReduceStrategyTest( - CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -267,7 +308,7 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, @@ -277,10 +318,30 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testComplexModel(self, 
num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + # TODO(yuefengz): Update how we use num_gpus and required_gpus + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMakeInputFnIterator(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + # We use CPU as the device when num_gpus = 0 + devices_per_worker = max(1, num_gpus) + expected_values = [[i+j for j in range(devices_per_worker)] + for i in range(0, 100, devices_per_worker)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=3*devices_per_worker, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -321,20 +382,36 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_minimize_loss_graph(None, None, num_gpus) def testComplexModel(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_complex_model(None, None, num_gpus) + def testMakeInputFnIterator(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. 
+ if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 63a163e76cdd99c73399c657cbe9bc3d010369d2..f3ce547f4d0ffc8d507c77adb22293edf7c54373 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -168,6 +168,8 @@ def _augment_with_special_arguments(test_method): if GPU_TEST: self.skipTest("Test that doesn't require GPUs.") elif context.num_gpus() < required_gpus: + # TODO(priyag): Consider allowing tests in graph mode using soft + # placement. self.skipTest( "{} GPUs are not available for this test. {} GPUs are available". format(required_gpus, context.num_gpus())) @@ -335,17 +337,35 @@ tpu_strategy_one_step = NamedDistribution( "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) -# Note that we disable prefetching for testing since prefetching makes -# the input non-deterministic. 
+mirrored_strategy_with_one_cpu = NamedDistribution( + "Mirrored1CPU", + lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) +mirrored_strategy_with_one_gpu = NamedDistribution( + "Mirrored1GPU", + lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]), + required_gpus=1) mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), + required_gpus=2) +core_mirrored_strategy_with_one_cpu = NamedDistribution( + "CoreMirrored1CPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/cpu:0"])) +core_mirrored_strategy_with_one_gpu = NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "CoreMirroredCPUAndGPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/cpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_two_gpus = NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) @@ -377,8 +397,11 @@ def distributions_and_v1_optimizers(): """A common set of combination with DistributionStrategies and Optimizers.""" return combine( distribution=[ - one_device_strategy, mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v1) @@ -387,7 +410,10 @@ def distributions_and_v2_optimizers(): """DistributionStrategies and V2 Optimizers.""" return combine( 
distribution=[ - one_device_strategy, mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py similarity index 74% rename from tensorflow/contrib/distribute/python/cross_tower_ops_test.py rename to tensorflow/contrib/distribute/python/cross_device_ops_test.py index 6a9e8e00c02411d6486f30146f7f7d86ecd2fa9c..40410b90be7d9d9ed20fb4e696565cf79c044553 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -24,28 +24,28 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs from 
tensorflow.python.training import device_util -def _make_per_device(values, devices, regroup=False): - devices = cross_tower_ops_lib.get_devices_from(devices) +def _make_per_replica(values, devices, regroup=False): + devices = cross_device_ops_lib.get_devices_from(devices) assert len(values) == len(devices) - # We simulate the result of regroup called on PerDevice which strips the - # PerDevice wrapper if it has only one value. + # We simulate the result of regroup called on PerReplica which strips the + # PerReplica wrapper if it has only one value. if len(values) == 1 and regroup: with ops.device(devices[0]): placed_v = array_ops.identity(values[0]) @@ -56,7 +56,7 @@ def _make_per_device(values, devices, regroup=False): with ops.device(d): placed_v = array_ops.identity(v) index[d] = placed_v - return value_lib.PerDevice(index) + return value_lib.PerReplica(index) # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -66,7 +66,7 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib.get_devices_from(devices) + devices = cross_device_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) @@ -118,15 +118,15 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) - def _testReductionAndBroadcast(self, cross_tower_ops, distribution): - devices = distribution.worker_devices + def _testReductionAndBroadcast(self, cross_device_ops, distribution): + devices = distribution.extended.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] - per_device = _make_per_device(values, devices) + per_replica = _make_per_replica(values, devices) mean = (len(devices) - 1.) / 2. 
values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) + per_replica_2 = _make_per_replica(values_2, devices) mean_2 = mean + 1. destination_mirrored = _fake_mirrored(1., devices) @@ -142,41 +142,43 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # test reduce() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, - per_device, + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, + per_replica, destinations=destinations), _fake_mirrored(mean, destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, - per_device_2, + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_device, + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices), destinations)) self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, - per_device_2, + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), (per_device_2, d2)]), + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ]) self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), (per_device_2, d2)]), + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.SUM, + [(per_replica, d1), (per_replica_2, d2)]), [ 
_fake_mirrored(mean * len(devices), d1), _fake_mirrored(mean_2 * len(devices), d2) @@ -185,7 +187,7 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): # test broadcast() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + cross_device_ops.broadcast(constant_op.constant(1.), destinations), _fake_mirrored(1., destinations)) @@ -194,62 +196,65 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): # combinations module so that we can pass in devices instead of a distribution # strategy. reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "DefaultReductionToOneDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), combinations.NamedObject( "ReductionToCPUDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_cpu_device)), combinations.NamedObject( "AccumulateNCrossDeviceOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( accumulation_fn=math_ops.accumulate_n)), ], distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], mode=["graph", "eager"]) allreduce_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "AllReduce", - cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), combinations.NamedObject( "HierarchicalCopy", - 
cross_tower_ops_lib.AllReduceCrossDeviceOps( + cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 8, 0, 0)), combinations.NamedObject( "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), combinations.NamedObject( "HierarchicalCopyAggregateSmallTensors", - cross_tower_ops_lib.AllReduceCrossDeviceOps( + cross_device_ops_lib.AllReduceCrossDeviceOps( "hierarchical_copy", 0, 100, 10)) ], - distribution=[combinations.mirrored_strategy_with_two_gpus], + distribution=[combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph", "eager"]) @combinations.generate(reduction_to_one_combinations + allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def testReductionAndBroadcast(self, cross_device_ops, distribution): with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) + self._testReductionAndBroadcast(cross_device_ops, distribution) def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "hierarchical_copy") self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + 
self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "nccl") self.assertEqual(result._num_packs, 1) @@ -257,16 +262,16 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "hierarchical_copy") self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossDeviceOps) + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) self.assertEqual(result._all_reduce_alg, "nccl") self.assertEqual(result._num_packs, 1) @@ -277,9 +282,9 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): devices = ["/cpu:0", "/gpu:0"] t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) - result = cross_tower_ops_lib._simple_reduce( - per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + result = cross_device_ops_lib._simple_reduce( + per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) # Test that the result 
is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. @@ -292,41 +297,42 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): @combinations.generate( combinations.combine( - cross_tower_ops_instance=[ + cross_device_ops_instance=[ combinations.NamedObject( "ReductionToOneDeviceCrossDeviceOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), combinations.NamedObject( "AllReduceCrossDeviceOps", - cross_tower_ops_lib.AllReduceCrossDeviceOps()) + cross_device_ops_lib.AllReduceCrossDeviceOps()) ], - aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN], batch_reduce=[True, False], mode=["graph", "eager"], required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op, batch_reduce): devices = ["/cpu:0", "/gpu:0"] dense_shape = [5, 2] t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) t1 = _make_indexed_slices( [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce(aggregation, - [(per_device, devices)]) + result = cross_device_ops_instance.batch_reduce( + reduce_op, [(per_replica, devices)]) else: - result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) + result = cross_device_ops_instance.reduce( + reduce_op, per_replica, devices) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] - if aggregation == vs.VariableAggregation.SUM: + if reduce_op == reduce_util.ReduceOp.SUM: total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] total_values_without_dups = [[4., 6.], [5., 6.]] else: - assert 
aggregation == vs.VariableAggregation.MEAN + assert reduce_op == reduce_util.ReduceOp.MEAN total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] total_values_without_dups = [[2., 3.], [2.5, 3.]] @@ -353,49 +359,65 @@ class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" ] multi_worker_allreduce_combinations = combinations.combine( - cross_tower_ops=[ + cross_device_ops=[ combinations.NamedObject( "MultiWorkerAllReduce", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), combinations.NamedObject( "MultiWorkerAllReducePack", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), combinations.NamedObject( "MultiWorkerAllReduceAggregation", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), combinations.NamedObject( "MultiWorkerAllReduceMultipleSpecs", - cross_tower_ops_lib.MultiWorkerAllReduce( + cross_device_ops_lib.MultiWorkerAllReduce( worker_devices, 2, [("pscpu/pscpu", 2, 100), ("xring", 2, -1)], 0, 0, 0)), ], distribution=[ combinations.NamedDistribution( "MirroredCPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=0), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=0), required_gpus=0), combinations.NamedDistribution( "Mirrored1GPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=1), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=1), required_gpus=1), combinations.NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=2), + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2), + required_gpus=2), + # pylint: disable=g-long-lambda + combinations.NamedDistribution( + "CoreMirroredCPU", + lambda: 
mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=0), + required_gpus=0), + combinations.NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=1), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=2), required_gpus=2), ], mode=["graph"]) @combinations.generate(multi_worker_allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def testReductionAndBroadcast(self, cross_device_ops, distribution): distribution.configure(cluster_spec={ "worker": ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"] }) with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) + self._testReductionAndBroadcast(cross_device_ops, distribution) class MultiWorkerCollectiveAllReduceTest( @@ -416,7 +438,7 @@ class MultiWorkerCollectiveAllReduceTest( MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + MultiWorkerCollectiveAllReduceTest.collective_key_base, instance_key_start=num_gpus * 100 + @@ -424,7 +446,7 @@ class MultiWorkerCollectiveAllReduceTest( instance_key_with_id_start=num_gpus * 10000 + MultiWorkerCollectiveAllReduceTest.collective_key_base) if local_mode: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( 1, num_gpus, collective_keys=collective_keys) if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] @@ -432,7 +454,7 @@ class MultiWorkerCollectiveAllReduceTest( devices = ["/device:CPU:0"] return collective_all_reduce_ops, devices, "" else: - collective_all_reduce_ops = 
cross_tower_ops_lib.CollectiveAllReduce( + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( 3, num_gpus, collective_keys=collective_keys) if num_gpus: devices = [ @@ -478,11 +500,11 @@ class MultiWorkerCollectiveAllReduceTest( # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_device = _make_per_device(values, devices, regroup=True) + per_replica = _make_per_replica(values, devices, regroup=True) mean = np.array([(len(devices) - 1.) / 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) + per_replica_2 = _make_per_replica(values_2, devices) mean_2 = np.array([mean[0] + 1.]) destination_mirrored = _fake_mirrored(1., devices) @@ -499,27 +521,27 @@ class MultiWorkerCollectiveAllReduceTest( for destinations in all_destinations: self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, - per_device, + reduce_util.ReduceOp.MEAN, + per_replica, destinations=destinations), _fake_mirrored(mean, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, - per_device_2, + reduce_util.ReduceOp.MEAN, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, - per_device, + reduce_util.ReduceOp.SUM, + per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( - vs.VariableAggregation.SUM, - per_device_2, + reduce_util.ReduceOp.SUM, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), sess) @@ -527,17 +549,17 @@ class MultiWorkerCollectiveAllReduceTest( # test batch_reduce() for d1, d2 in 
itertools.product(all_destinations, all_destinations): self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), - (per_device_2, d2)]), + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, + [(per_replica, d1), + (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ], sess) self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), - (per_device_2, d2)]), + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, + [(per_replica, d1), + (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices) * num_workers, d1), _fake_mirrored(mean_2 * len(devices) * num_workers, d2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py similarity index 75% rename from tensorflow/contrib/distribute/python/cross_tower_utils_test.py rename to tensorflow/contrib/distribute/python/cross_device_utils_test.py index d25964fa41adc7b1c9164a4ffe49c4c5532f76ac..6086eba0984782f5e85235142817569bee135df0 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for cross_tower_utils.""" +"""Tests for cross_device_utils.""" from __future__ import absolute_import from __future__ import division @@ -21,8 +21,8 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils -from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -43,7 +43,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self._assert_values_equal(total, result) @test_util.run_in_graph_and_eager_modes @@ -53,7 +53,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(total, result) @@ -62,7 +62,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) n = 2 expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = 
cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) self._assert_values_equal(expected, result) @test_util.run_in_graph_and_eager_modes @@ -71,7 +71,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) n = 2 expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(expected, result) @@ -79,7 +79,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): def testIsIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + self.assertTrue(cross_device_utils.contains_indexed_slices(t)) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_List(self): @@ -87,7 +87,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1])) @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_Tuple(self): @@ -95,27 +95,16 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1))) @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDevice(self): + def 
testContainsIndexedSlices_PerReplica(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDeviceMapOutput(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({ - "/gpu:0": value_lib.MapOutput([t0]), - "/cpu:0": value_lib.MapOutput([t1])}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( mode=["graph", "eager"], @@ -124,7 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): with ops.device("/cpu:0"): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( t, destination) self._assert_values_equal(t, result) @@ -139,7 +128,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( t, destination) self.assertIsInstance(result, ops.IndexedSlices) diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 
a1355c0b09e51c18cc4f8967dfc2c472d63593b9..e17085628ba6d1dfc79839fd824801723f07a518 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -34,7 +34,7 @@ from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -63,7 +63,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): @@ -75,12 +77,12 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=True) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py 
b/tensorflow/contrib/distribute/python/estimator_training_test.py index 018512ae5a22eaa7fb78a8c4e5918fec22eb8178..0f35657a8099523b6ba5b8f0a1a2f289c06b531a 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -45,11 +45,13 @@ from tensorflow.python.estimator import training as estimator_training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import session_manager + BATCH_SIZE = 10 LABEL_DIMENSION = 2 @@ -202,10 +204,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={"x": DATA}, y=DATA, - batch_size=BATCH_SIZE // len(train_distribute.worker_devices), + batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync, shuffle=True) if eval_distribute: - eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices) + eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync else: eval_batch_size = BATCH_SIZE eval_input_fn = self.dataset_input_fn( @@ -291,19 +293,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_distribute_cls=[ collective_all_reduce_strategy.CollectiveAllReduceStrategy, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy ], eval_distribute_cls=[ - None, mirrored_strategy.MirroredStrategy, + None, + mirrored_strategy.MirroredStrategy, + 
mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - try: - train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) - except TypeError: - train_distribute = train_distribute_cls(num_gpus_per_worker=2) + train_distribute = train_distribute_cls( + num_gpus_per_worker=context.num_gpus()) if eval_distribute_cls: eval_distribute = eval_distribute_cls( @@ -324,10 +327,12 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, mode=["graph"], train_distribute_cls=[ mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, ], required_gpus=[0, 1])) def test_estimator_standalone_client(self, train_distribute_cls, @@ -407,6 +412,7 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy, ], required_gpus=[0, 1])) @@ -451,8 +457,15 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, @combinations.generate( combinations.combine( mode=["graph"], - train_distribute_cls=[mirrored_strategy.MirroredStrategy], - eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy], + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], + eval_distribute_cls=[ + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, eval_distribute_cls): @@ -508,7 +521,8 @@ class RunConfigTest(test.TestCase): "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - 
train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=2))) def test_should_run_distribute_coordinator(self): """Tests that should_run_distribute_coordinator return a correct value.""" @@ -531,10 +545,12 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config_with_train_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=2))) config_with_eval_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + eval_distribute=mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=2))) self.assertTrue( dc_training.should_run_distribute_coordinator( config_with_train_distribute)) @@ -547,26 +563,27 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=2))) self.assertFalse(dc_training.should_run_distribute_coordinator(config)) def test_init_run_config_duplicate_distribute(self): with self.assertRaises(ValueError): run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy())) + train_distribute=mirrored_strategy.CoreMirroredStrategy())) with self.assertRaises(ValueError): run_config_lib.RunConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(), + eval_distribute=mirrored_strategy.CoreMirroredStrategy(), 
experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy())) + eval_distribute=mirrored_strategy.CoreMirroredStrategy())) def test_init_run_config_none_distribute_coordinator_mode(self): # We don't use distribute coordinator for local training. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) dc_training.init_run_config(config, {}) self.assertIsNone(config._distribute_coordinator_mode) @@ -574,7 +591,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertIsNone(config._distribute_coordinator_mode) # When `train_distribute` is not specified, don't use distribute @@ -590,7 +607,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertEqual(config._distribute_coordinator_mode, dc.CoordinatorMode.INDEPENDENT_WORKER) @@ -599,7 +616,7 @@ class RunConfigTest(test.TestCase): # `experimental.remote_cluster` is set use distribute coordinator with # STANDALONE_CLIENT mode. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( remote_cluster={"chief": ["fake_worker"]})) self.assertEqual(config._distribute_coordinator_mode, @@ -607,5 +624,15 @@ class RunConfigTest(test.TestCase): if __name__ == "__main__": + # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly. 
+ orig_init = session_manager.SessionManager.__init__ + + def new_init(*args, **kwargs): + kwargs.pop("recovery_wait_secs", None) + kwargs["recovery_wait_secs"] = 0.5 + orig_init(*args, **kwargs) + + session_manager.SessionManager.__init__ = new_init + with test.mock.patch.object(sys, "exit", os._exit): test.main() diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index c7036daa3e3321a29e4fc9ae30449fbf15b69b1b..0fd3acd045170c04ebdaa9c84d0cb7267a4bc68a 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -61,7 +61,6 @@ def get_input_datasets(use_bfloat16=False): # train dataset train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() - train_ds = train_ds.shuffle(100) train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) train_ds = train_ds.batch(64, drop_remainder=True) diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index f4c222f26c3f6501cd78a69dd6a6d9a442a6bd24..fba06283ce560390b9a408ac7ceb30bbe17a754b 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -25,23 +25,28 @@ import numpy as np import six from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from 
tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import distribution_strategy_context as ds_context class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): @@ -64,7 +69,9 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): @@ -76,11 +83,11 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices)) + batch_size=batch_size // distribution.num_replicas_in_sync) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices)) + batch_size=batch_size // distribution.num_replicas_in_sync) predict_input_fn = numpy_io.numpy_input_fn( x={'x': 
data}, batch_size=batch_size, shuffle=False) @@ -136,44 +143,51 @@ class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): shutil.rmtree(self._model_dir) -class MirroredStrategyOptimizerV2Test(test.TestCase): +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model - def testKerasOptimizerWithUnequalInput(self): - if context.num_gpus() < 1: - self.skipTest('Not enough GPUs.') - def create_fn(device_id): +class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testKerasOptimizerWithUnequalInput(self, distribution): + def create_fn(): var = variables.Variable( 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. 
- loss = (device_id + 1) * var + loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) train_op = optimizer.minimize(loss, var_list=[var]) m = optimizer.get_slot(var, 'm') v = optimizer.get_slot(var, 'v') - return (var, m, v, train_op, optimizer.iteration) + return (var, m, v, train_op, optimizer.iterations) devices = ['/device:GPU:0', '/device:CPU:0'] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - (var, m, v, op, counter) = dist.call_for_each_replica( - create_fn, dist.worker_device_index, run_concurrently=False) + with distribution.scope(): + (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) self.evaluate(variables.global_variables_initializer()) var_val = [2.0, 2.0, 2.0] self.assertAllClose( var_val, self.evaluate( - [dist.read_var(var), + [distribution.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([0, 0, 0], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) - train_op = dist.unwrap(op) + train_op = distribution.unwrap(op) self.evaluate(train_op) # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 m_val = [1.2, 1.2, 1.2] @@ -181,7 +195,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( m_val, self.evaluate( - [dist.read_var(m), + [distribution.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 @@ -189,7 +203,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( v_val, self.evaluate( - [dist.read_var(v), + [distribution.read_var(v), v.get(devices[0]), v.get(devices[1])])) # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) @@ -198,12 +212,12 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( var_val, self.evaluate( - 
[dist.read_var(var), + [distribution.read_var(var), var.get(devices[0]), var.get(devices[1])])) self.assertAllClose([1, 1, 1], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) @@ -214,7 +228,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( m_val, self.evaluate( - [dist.read_var(m), + [distribution.read_var(m), m.get(devices[0]), m.get(devices[1])])) # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 @@ -222,16 +236,50 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): self.assertAllClose( v_val, self.evaluate( - [dist.read_var(v), + [distribution.read_var(v), v.get(devices[0]), v.get(devices[1])])) self.assertAllClose([2, 2, 2], self.evaluate([ - dist.read_var(counter), + distribution.read_var(counter), counter.get(devices[0]), counter.get(devices[1]) ])) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): + + with self.cached_session(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit( + inputs, + targets, + epochs=1, + batch_size=2, + verbose=0, + validation_data=(inputs, targets)) + model.evaluate(inputs, targets) + model.predict(inputs) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py 
b/tensorflow/contrib/distribute/python/keras_test.py index eccff1d9f57a6753a9c4ed745931b3108329b2a6..29d85fe971ff291df9e9ddf74c0082393bf55ba6 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -24,9 +24,9 @@ import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import constant_op @@ -47,7 +47,6 @@ _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 -_TOLERANCE = 1e-5 # TODO(anjalisridhar): Add a decorator that will allow us to run these tests as @@ -213,10 +212,77 @@ def multi_input_output_model(): return model +def get_correctness_test_inputs(use_numpy, with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': 1, + 'shuffle': False, + } + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(train_dataset, batch_size, with_distribution) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': 1, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) + if use_per_core_batch_size: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + strategies = [combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus, + combinations.tpu_strategy, # steps_per_run=2 combinations.tpu_strategy_one_step] @@ -225,7 +291,9 @@ def strategy_minus_tpu_combinations(): distribution=[combinations.default_strategy, combinations.one_device_strategy, 
combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], mode=['graph']) @@ -245,7 +313,15 @@ def strategy_and_optimizer_combinations(): mode=['graph']) -class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): +def strategy_and_inputs(): + return combinations.combine( + distribution=strategies, + use_numpy=[True, False], + mode=['graph']) + + +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -253,17 +329,18 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.MakeDirs(self._base_dir) self._config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) - self._dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) def tearDown(self): writer_cache.FileWriterCache.clear() if os.path.isdir(self._base_dir): gfile.DeleteRecursively(self._base_dir) - def test_train_functional_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_functional_with_distribution_strategy(self, distribution): keras_model = simple_functional_model() keras_model.compile( loss='categorical_crossentropy', @@ -271,8 +348,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist, - eval_distribute=dist) + train_distribute=distribution, 
+ eval_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -286,9 +363,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_train_sequential_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_sequential_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -296,7 +376,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -310,7 +390,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() def train_input_fn(): @@ -340,14 +425,14 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): output_dict)).batch(16) self.do_test_multi_inputs_multi_outputs_with_input_fn( - 
train_input_fn, eval_input_fn) + distribution, train_input_fn, eval_input_fn) - def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn, - eval_input_fn): + def do_test_multi_inputs_multi_outputs_with_input_fn( + self, distribution, train_input_fn, eval_input_fn): config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=self._dist) + train_distribute=distribution) with self.cached_session(): model = multi_inputs_multi_outputs_model() est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) @@ -357,9 +442,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) self.assertLess(eval_results['loss'], baseline_eval_results['loss']) - def test_keras_optimizer_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -367,7 +455,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator(keras_model=keras_model, config=config) @@ -392,82 +480,133 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # Verify that the numpy value is copied to the variable. self.assertAllEqual(x, val) - def test_calculating_batch_params(self): - # This verifies that we calculate the number of steps when the batch size - # is specified. 
+ @combinations.generate(strategy_combinations()) + def test_calculating_input_params_no_steps_no_batch_size(self, distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + with self.cached_session(): - # 64 is the number of input samples. - inputs = np.zeros((64, 3), dtype=np.float32) - # The number of replicas is equal to 3. - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0', - '/device:GPU:1']) - - with self.assertRaisesRegexp(ValueError, 'Please specify a batch_size ' - 'that is smaller than'): - # The batch size(128) is larger than the number of input - # samples(64). - distributed_training_utils.get_input_batch_params(inputs, - 128, - strategy) - - with self.assertRaisesRegexp(ValueError, 'is smaller than the number ' - 'of replicas'): - # The batch size(32) * num_replicas(3) is 96 which is greater than the - # number of input samples(64). - distributed_training_utils.get_input_batch_params(inputs, - 32, - strategy) - - # The number of replicas now is equal to 2. - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - # 32 is the batch size per replica. - steps = distributed_training_utils.get_input_batch_params(inputs, - 32, - strategy) - # The number of batches is the ratio of input samples(64) to - # batch size(32) which is 2. The number of steps(1) is the ratio of - # number of batches(2) to the number of replicas(2). 
+ # Input samples of different sizes + input_20_samples = np.zeros((20, 3), dtype=np.float32) + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Default global batch size 32 for input with 64 samples run in 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) + + # Computed global batch size 20 is lower than 32 if we pass less samples. + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_20_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 20 // replica_scale_factor) self.assertEqual(steps, 1) - # 16 is the batch size per replica. - steps = distributed_training_utils.get_input_batch_params(inputs, - 16, - strategy) - # The number of batches is the ratio of input samples(64) to - # batch size(16) which is 4. The number of steps(2) is the ratio of - # number of batches(4) to the number of replicas(2). + # Default global batch size 32 cannot be used with 63 samples. 
+ with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=None, batch_size=None) + + @combinations.generate(strategy_combinations()) + def test_calculating_input_params_with_steps_no_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + # Input samples of different sizes + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed global batch size is correct for number of specified 1 step + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=1, batch_size=None) + self.assertEqual(batch_size, 64 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Computed global batch size is correct for number of specified 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=2, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) self.assertEqual(steps, 2) - def test_calculating_batch_size(self): + # All samples can not be consumed in specified number of steps + with self.assertRaisesRegexp(ValueError, 'not divisible by steps'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=2, batch_size=None) + + # This cases is different for different strategies due to the + # difference in supported batch size being global or per-replica. 
+ if replica_scale_factor == 1: + # Computed global batch size is correct even if not sharadable + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=3, batch_size=None) + self.assertEqual(batch_size, 21) + self.assertEqual(steps, 3) + else: + # Computed global batch size can not be sharded across replicas + with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly ' + 'across the sync replicas'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=1, batch_size=None) + + @combinations.generate(strategy_combinations()) + def test_calculating_input_params_no_steps_with_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + with self.cached_session(): - # 64 is the number of input samples. 
- inputs = np.zeros((64, 3), dtype=np.float32) - targets = np.zeros((64, 4), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=16) + self.assertEqual(batch_size, 16) + self.assertEqual(steps, 4 // replica_scale_factor) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=32) + self.assertEqual(batch_size, 32) + self.assertEqual(steps, 2 // replica_scale_factor) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=20) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=3) - model = get_model() - optimizer = gradient_descent.GradientDescentOptimizer(0.001) - loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - strategy._require_static_shapes = True - - model.compile(optimizer, loss, distribute=strategy) - iterator = model._distribution_standardize_user_data(inputs, - targets, - batch_size=None, - check_steps=True, - steps_name='steps', - steps=3) - - # The global batch size(21) across all replicas is the ratio of the input - # samples(64) to the steps(3). - # The batch size(10) per device is the ratio of the global batch size(21) - # to the number of replicas(2). - # The global batch size and batch size are rounded integer values. 
- self.assertEqual(10, distributed_training_utils.get_batch_dimension( - iterator._iterator)) + @combinations.generate(strategy_combinations()) + def test_calculating_input_params_with_steps_with_batch_size(self, + distribution): + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # No change to steps and batch size if both specified and feasible + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=5, batch_size=3) + self.assertEqual(batch_size, 3) + self.assertEqual(steps, 5) + + # Number of samples is less than global batch size * steps + with self.assertRaisesRegexp(ValueError, 'less than samples required'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=10, batch_size=13) @combinations.generate(strategy_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): @@ -541,9 +680,9 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, loss = 'mse' model.compile(optimizer, loss, distribute=distribution) - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - sample_weights = np.ones((10), np.float32) + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, steps_per_epoch=2, verbose=1) @@ -566,7 +705,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # `predict` a list that is equal in length to the number of model outputs. # In this test our model has two outputs and each element of `outs` # corresponds to all the samples of one of the model outputs. - self.assertEqual(2, len(outs)) + self.assertLen(outs, 2) # Each of the output samples have a dimension of 7. We should process all # the available input samples(6). 
self.assertAllEqual([6, 7], outs[0].shape) @@ -598,36 +737,33 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - loss = 'mse' - user_controlled_model = get_model() - user_controlled_optimizer = gradient_descent.GradientDescentOptimizer( - 0.001) - user_controlled_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - user_controlled_model.compile(user_controlled_optimizer, loss, - metrics=user_controlled_metrics, - distribute=distribution) + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) interleaved_model = get_model() - interleaved_optimizer = gradient_descent.GradientDescentOptimizer(0.001) - interleaved_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - interleaved_model.compile(interleaved_optimizer, loss, - metrics=interleaved_metrics, - distribute=distribution) + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) dataset = get_dataset(distribution) # Call fit with validation interleaved - interleaved_output = interleaved_model.fit(dataset, epochs=2, - steps_per_epoch=2, verbose=0, - validation_data=dataset, - validation_steps=2) + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) # Manually control the validation running after each epoch. 
user_controlled_output = [] for _ in range(2): user_controlled_model.fit( - dataset, epochs=1, steps_per_epoch=2, verbose=0) + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) user_controlled_output.append( user_controlled_model.evaluate(dataset, steps=2)) @@ -641,16 +777,20 @@ class TestDistributionStrategyWithDatasets(test.TestCase, # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work # as clone_model's input_tensors argument only seems to accept list and not # tuples or dict. - def test_fit_with_tuple_and_dict_dataset_inputs(self): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): with self.cached_session(): model = multi_input_output_model() optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -723,35 +863,48 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.evaluate(dataset, steps=2, verbose=1) model.predict(dataset, steps=2) - def test_dataset_input_shape_validation(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - 
'/device:GPU:0']) - - model.compile(optimizer, loss, distribute=strategy) + model.compile(optimizer, loss, distribute=distribution) - # User forgets to batch the dataset - inputs = np.zeros((10, 3), dtype=np.float32) + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) + dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - # Wrong input shape - inputs = np.zeros((10, 5), dtype=np.float32) + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_dataset_no_batch_input_validation(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, - 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( @@ -773,7 +926,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - def test_learning_phase_value(self): + @combinations.generate(combinations.combine( + distribution=[ + 
combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. @@ -787,42 +945,50 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.005) loss = 'mse' metrics = ['acc'] - strategy = mirrored_strategy.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + batch_size = 8 + if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): + # CoreMirroredStrategy uses global batch size. + batch_size = 8 * distribution.num_replicas_in_sync inputs = np.ones((10, 1), dtype=np.float32) targets = np.ones((10, 1), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat().batch(8) + dataset = dataset.repeat().batch(batch_size) hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) self.assertAlmostEqual(hist.history['acc'][0], 0, 0) model.set_weights(initial_weights) - evaluate_output = model.evaluate(dataset, steps=20) - self.assertAlmostEqual(evaluate_output[1], 1, 0) + # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. + # evaluate_output = model.evaluate(dataset, steps=20) + # self.assertAlmostEqual(evaluate_output[1], 1, 0) inputs = np.ones((10, 1), dtype=np.float32) predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) - predict_dataset = predict_dataset.repeat().batch(5) + + predict_dataset = predict_dataset.repeat().batch(batch_size) output = model.predict(predict_dataset, steps=10) - # `predict` runs for 10 steps and in each step you process 100 samples. 
- ref_output = np.ones((100, 1), dtype=np.float32) + # `predict` runs for 10 steps + ref_output = np.ones((160, 1), dtype=np.float32) self.assertArrayNear(output, ref_output, 1e-1) class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - def test_validating_dataset_input_tensors_with_shape_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def test_validating_dataset_input_tensors_with_shape_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2)) b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor shape details from the error message # since the order of the device and the corresponding input tensor shape # is not deterministic over different runs. 
@@ -831,17 +997,21 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor dtype details from the error message # since the order of the device and the corresponding input tensor dtype # is not deterministic over different runs. 
@@ -850,21 +1020,23 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_unsupported_features(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_unsupported_features(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) # Test with validation split with self.assertRaisesRegexp( @@ -899,18 +1071,21 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'you should specify the `steps` argument'): model.predict(dataset, verbose=0) - def test_calling_with_unsupported_predefined_callbacks(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) 
def schedule(_): return 0.001 @@ -933,11 +1108,17 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) -class TestDistributionStrategyWithLossMasking(test.TestCase): +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. - def test_masking(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) @@ -946,12 +1127,9 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): model.add( keras.layers.TimeDistributed( keras.layers.Dense(1, kernel_initializer='one'))) - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -1018,26 +1196,39 @@ class TestDistributionStrategyCorrectness(test.TestCase, distribute=distribution) batch_size = 64 - batch_size //= distribution.num_replicas + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) self.assertEqual(history.history['binary_accuracy'], [1.0]) - @combinations.generate(strategy_combinations()) - def test_correctness(self, distribution): + 
@combinations.generate(strategy_and_inputs()) + def test_correctness(self, distribution, use_numpy): with self.cached_session(): + tolerance = 1e-5 + + if isinstance(distribution, (mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy)): + # TODO(b/119257215): use the default one once the flakyness is fixed. + tolerance = 1e-4 + keras.backend.set_image_data_format('channels_last') - num_samples = 10000 np.random.seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED) - # Train and predict datasets are created with the same input numpy arrays. + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 x_train = np.random.rand(num_samples, 1) y_train = 3 * x_train x_train = x_train.astype('float32') y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] # The model is built once and the initial weights are saved. # This is used to initialize the model for both the distribution and @@ -1051,52 +1242,38 @@ class TestDistributionStrategyCorrectness(test.TestCase, initial_weights = model.get_weights() def fit_and_predict(with_distribution=None): + # We have initialized the model to the same weight for the distribution + # and non-distribution run. model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), distribute=with_distribution) - batch_size = 64 - if with_distribution: - batch_size //= with_distribution.num_replicas - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, - y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - # We have initialized the model to the same weight for the distribution - # and non-distribution run. 
If you want to initialize the model to - # random weights for each run, you need to run the model through the - # entire dataset at least once to ensure that the weights converge to - # the same value. - model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, with_distribution, + x_train, y_train, x_predict)) + model.fit(**training_inputs) + eval_result = model.evaluate(**eval_inputs) weights = model.get_weights() - x_predict = [[1.], [2.], [3.], [4.]] - predict_batch_size = 4 - if with_distribution: - predict_batch_size //= with_distribution.num_replicas - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, distribution) - predict_result = model.predict(predict_dataset, steps=1) - - return weights, predict_result - - wts_with_ds, predict_with_ds = fit_and_predict( + predict_result = model.predict(**predict_inputs) + + return weights, eval_result, predict_result + + wts_with_ds, eval_with_ds, predict_with_ds = fit_and_predict( with_distribution=distribution) - wts_without_ds, predict_without_ds = fit_and_predict( + wts_without_ds, eval_without_ds, predict_without_ds = fit_and_predict( with_distribution=None) - # Verify that the weights are the same within some limits of tolerance. + # Verify that the weights, eval results, predict outputs are the same + # within some limits of tolerance. self.assertAllClose( - wts_with_ds, wts_without_ds, atol=_TOLERANCE, rtol=_TOLERANCE) - # Verify that the predicted outputs are the same within some limits of - # tolerance. + wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) self.assertAllClose( - predict_with_ds, predict_without_ds, atol=_TOLERANCE, rtol=_TOLERANCE) - - -# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. 
+ eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) + self.assertAllClose( + predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 9e1a7ad3932e3e8b79c70f1c07a241dcf52564f1..8ac659abe96370b751ed1556cc699fe20788a0fd 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -72,14 +72,14 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() -# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using -# ReplicaLocalVariables on TPUs. Submit http://cl/208914352. def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph"]) @@ -100,25 +100,26 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): value, update = distribution.call_for_each_replica( - metric_fn, inputs) + metric_fn, args=inputs) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) ctx = distribution.run_steps_on_dataset( - step_fn, iterator, iterations=distribution.steps_per_run) + step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = ctx.non_tensor_outputs["value"] # In each run, we run multiple steps, and each steps consumes as many # batches as number of replicas. 
batches_per_update = ( - distribution.num_replicas * distribution.steps_per_run) + distribution.num_replicas_in_sync * + distribution.extended.steps_per_run) else: value, update = distribution.call_for_each_replica( metric_fn, iterator.get_next()) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_replicas" with "1". - batches_per_update = distribution.num_replicas + # replace "distribution.num_replicas_in_sync" with "1". + batches_per_update = distribution.num_replicas_in_sync self.evaluate(iterator.initializer) self.evaluate(distribution.initialize()) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 165732d578fd25b1ef631efc5827fd636427c7c8..e77d3d455b0a79b2fac6a458c3aa009ff5c2f780 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -22,10 +22,10 @@ from absl.testing import parameterized import numpy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -64,11 +64,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, *inputs, 
run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -111,7 +110,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def run_step(): return distribution.group( distribution.call_for_each_replica( - model_fn, iterator.get_next(), run_concurrently=layer.built)) + model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): with self.cached_session() as sess: @@ -159,11 +158,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): use_callable_loss=True, create_optimizer_inside_model_fn=True) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -221,7 +219,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm, update_ops_in_cross_replica_mode): """Verifies that moving mean updates are reduced across replicas.""" with distribution.scope(): - num_replicas = len(distribution.worker_devices) + num_replicas = distribution.num_replicas_in_sync model_fn, dataset_fn, batchnorm = batchnorm_example( optimizer_fn, batch_per_epoch=num_replicas, @@ -229,17 +227,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm=renorm, update_ops_in_replica_mode=not update_ops_in_cross_replica_mode) - # Make sure prefetching is disabled since that makes the - # specific input on each device to be non deterministic, and - # this test relies on specific input being on each device. 
- if isinstance(distribution, mirrored_strategy.MirroredStrategy): - self.assertFalse(distribution._prefetch_on_device) - - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=batchnorm.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) if update_ops_in_cross_replica_mode: fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) @@ -295,7 +286,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ]), combinations.combine( mode=["graph"], use_callable_loss=[True, False]) + @@ -331,11 +324,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) return dataset_ops.Dataset.zip((features, labels)).repeat() - def step_fn(ctx, x, y): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, x, y, run_concurrently=False)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -369,10 +361,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 # with sum loss reduction, or 10.6 with mean. 
if loss_reduction == losses_impl.Reduction.SUM: - # Note that the "distribution.num_replicas" factor will go away once - # we split the input across replicas, instead of pulling a complete + # Note that the "distribution.num_replicas_in_sync" factor will go away + # once we split the input across replicas, instead of pulling a complete # batch of input per replica. - self.assertNear(weight, 2 + 21.2 * distribution.num_replicas, 0.0001) + self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync, + 0.0001) else: # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) @@ -412,21 +405,21 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): train_op = optimizer.minimize(loss_fn) loss = loss_fn() output_context.set_last_step_output( - name="replica_loss_agg", + name="replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_non_tensor_output(key1, value1) return (train_op, loss) - def step_fn(output_context, *inputs): + def step_fn(output_context, inputs): (train_op, loss) = distribution.call_for_each_replica( - model_fn, output_context, *inputs, run_concurrently=False) + model_fn, args=(output_context,) + inputs) output_context.set_last_step_output( - name="cross_replica_loss_agg", + name="cross_replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_last_step_output( - name="cross_replica_loss_noagg", + name="cross_replica_loss_not_reduced", output=loss) return distribution.group(train_op) @@ -434,16 +427,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def run_step(): initial_loss = lambda: constant_op.constant(1e7) - # Initial values corresponding to aggregated losses are just single - # tensors. 
But for non aggregated losses, we need to have initial + # Initial values corresponding to reduced losses are just single + # tensors. But for non reduced losses, we need to have initial # values that are of the same structure as non reduced losses. In # MirroredStrategy, this will be a list of losses, in TPUStrategy # it will be single tensor. Using `broadcast` followed by `unwrap` # gives us the desired initial value structure. initial_loop_values = { - "replica_loss_agg": initial_loss(), - "cross_replica_loss_agg": initial_loss(), - "cross_replica_loss_noagg": + "replica_loss_reduced": initial_loss(), + "cross_replica_loss_reduced": initial_loss(), + "cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } ctx = distribution.run_steps_on_dataset( @@ -453,17 +446,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["replica_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_replica_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_replica_loss_noagg"], - aggregated=False, distribution=distribution) - return (ctx.run_op, ctx.last_step_outputs["replica_loss_agg"]) + loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"], + reduced=False, distribution=distribution) + return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) self.evaluate(distribution.initialize()) if not context.executing_eagerly(): @@ -488,17 +481,16 @@ class 
MinimizeLossStepTest(test.TestCase, parameterized.TestCase): error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(error_is_not_increasing) - def _verify_loss_output(self, initial_loss, loss_output, aggregated, + def _verify_loss_output(self, initial_loss, loss_output, reduced, distribution): - if not aggregated: - self.assertEqual(distribution.num_replicas, - len(distribution.unwrap(loss_output))) + if not reduced: + self.assertLen(distribution.unwrap(loss_output), + distribution.num_replicas_in_sync) loss_output = distribution.reduce( - aggregation=variables_lib.VariableAggregation.MEAN, - value=loss_output, destinations="/device:CPU:0") + reduce_util.ReduceOp.MEAN, loss_output, destinations="/device:CPU:0") unwrapped_output = distribution.unwrap(loss_output) - self.assertEqual(1, len(unwrapped_output)) + self.assertLen(unwrapped_output, 1) loss_tensor = unwrapped_output[0] self.assertEqual(initial_loss.dtype, loss_tensor.dtype) self.assertEqual(initial_loss.shape, loss_tensor.shape) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index c23de0694984076b1c9a8da45219436fc38cd286..a3bcc8db88f9466811aa15d37e14a22eb5ce485e 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -12,300 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Class MirroredStrategy implementing DistributionStrategy.""" +"""Contrib version of MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -from functools import partial -import threading +import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import shared_variable_creator -from tensorflow.contrib.distribute.python import values -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import coordinator -from tensorflow.python.training import device_util +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import values from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.util import nest -# TODO(josh11b): Replace asserts in this file with if ...: raise ... 
- - -@contextlib.contextmanager -def _enter_graph(g): - if context.executing_eagerly(): - with g.as_default(), context.eager_mode(): - yield - else: - with g.as_default(): - yield - - -def _cpu_device(device): - cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) - return cpu_device.to_string() - - -class _RequestedStop(Exception): - pass - - -# _call_for_each_replica and _reduce_non_distributed_value are not members of -# MirroredStrategy so that they are generally not allowed to use anything -# specific to MirroredStrategy and thus can be shared with other distribution -# strategies. - - -# TODO(yuefengz): maybe create a common class for those who need to call this -# _call_for_each_replica. -def _call_for_each_replica(distribution, fn, *args, **kwargs): - """Run `fn` in separate threads, once per replica/worker device. - - Args: - distribution: the DistributionStrategy object. - fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. - - Returns: - Merged return value of `fn` across all replicas. - - Raises: - RuntimeError: If fn() calls get_replica_context().merge_call() a different - number of times from the available devices. - """ - run_concurrently = kwargs.pop("run_concurrently", True) - if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False - # Needed for per-thread device, etc. contexts in graph mode. 
- ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True - - coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(distribution.worker_devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredReplicaThread( # pylint: disable=protected-access - distribution, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredReplicaThread - # (`MRT`) threads. The execution waits until - # `MRT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_replica_context().merge_call()` is called. If `fn` is - # complete, then `MRT.done` is set to True. Otherwise, arguments - # of `get_replica_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_replica_context().merge_call` are then set to `MRT.merge_result`. - # Each such `get_replica_context().merge_call` call returns the - # `MRT.merge_result` for that thread when `MRT.should_run` event - # is reset again. Execution of `fn` resumes. 
- - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some replicas made a different number of " - "replica_context().merge_call() calls.") - # get_replica_context().merge_call() case - merge_args = values.regroup({t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MRT when we call merge_fn - # to ensure that if we have opened a name scope in the MRT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MRT and assume it is - # the same for all other MRTs. 
- mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn(distribution, *merge_args, - **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) - - -def _reduce_non_distributed_value(distribution, aggregation, value, - destinations): - """Reduce a non-DistributedValue `value` to `destinations`.""" - if isinstance(value, values.DistributedValues): - raise ValueError("You are passing a `DistributedValue` to " - "`_reduce_non_distributed_value`, which is not allowed.") - - # If the same value is present on all replicas then the PerDevice value will - # be a single value. We also handle the case when `value` is a single value - # and equal to 0. - if value == 0: - return 0 - # If the aggregation type is MEAN or ONLY_FIRST_REPLICA, then this - # essentially means that the same value should be on all destinations. - if aggregation in ( - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA): - return value - - cross_tower_ops_lib.validate_destinations(destinations) - # We do not support an aggregation type of SUM if the value is the same across - # all replicas. We call this as part of assign functions for MirroredVariables - # and summing up identical values across replicas is not clearly defined. - if (len(distribution.worker_devices) != 1 or - not cross_tower_ops_lib.check_destinations(destinations)): - raise ValueError("A non-DistributedValues value %s cannot be reduced with " - "the given aggregation %s." % (value, aggregation)) - # TODO(anjalisridhar): Moves these methods to a device utility file? 
- devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - - -def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Get synchronization value - synchronization = kwargs.get("synchronization", - variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are replica local. - is_replica_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. 
- is_replica_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in ( - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - - if is_replica_local: - result = values.ReplicaLocalVariable( - index, index[devices[0]], aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. 
- if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - if v in l: - l.remove(v) - g.add_to_collections(collections, result) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) - - return result +# pylint: disable=protected-access,invalid-name +_call_for_each_replica = mirrored_strategy._call_for_each_replica +_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value +_create_mirrored_variable = mirrored_strategy._create_mirrored_variable +CoreMirroredStrategy = mirrored_strategy.MirroredStrategy +CoreMirroredExtended = mirrored_strategy.MirroredExtended +# pylint: enable=protected-access,invalid-name class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices and machines. + *** contrib version *** + This strategy uses one replica per device and sync replication for its multi-GPU version. @@ -348,8 +81,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): specified. cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not set, the `configure` method will try to find the best one. - prefetch_on_device: optional boolean to specify whether to prefetch input - data to devices. auto_shard_dataset: whether to auto-shard the dataset when there are multiple workers. cross_tower_ops: Deprecated alias for `cross_device_ops`. 
@@ -360,482 +91,62 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_device_ops=None, - prefetch_on_device=None, auto_shard_dataset=False, cross_tower_ops=None): - super(MirroredStrategy, self).__init__() - assert not (cross_device_ops and cross_tower_ops) - self._cross_tower_ops = cross_device_ops or cross_tower_ops - self._prefetch_on_device = prefetch_on_device - self._auto_shard_dataset = auto_shard_dataset - # Remember num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( "You cannot specify both `num_gpus` and `num_gpus_per_worker`.") - if num_gpus is not None: - self._num_gpus = num_gpus - else: - self._num_gpus = num_gpus_per_worker - - self._initialize_local(self._num_gpus, devices) - - def _initialize_local(self, num_gpus, devices): - """Initializes the object for local training.""" - self._cluster_spec = None - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: - if num_gpus is None: - num_gpus = context.num_gpus() - if num_gpus == 0: - devices = ["/device:CPU:0"] - else: - devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") - self._num_gpus = num_gpus - # TODO(yuefengz): consider setting the default device. - - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? 
- self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)}) - - def _initialize_multi_worker(self, num_gpus, cluster_spec): - """Initializes the object for multi-worker training.""" - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._cluster_spec = cluster_spec - - self._workers = [] - for job in ["chief", "worker"]: - for task in range(len(cluster_spec.as_dict().get(job, []))): - self._workers.append("/job:%s/task:%d" % (job, task)) - if num_gpus is None: - raise ValueError("`num_gpus` is required if `cluster_spec` is given.") - if num_gpus > 0: - self._worker_device_map = { - worker: [ - device_util.canonicalize(worker + "/device:GPU:%d" % gpu) - for gpu in range(num_gpus) - ] for worker in self._workers - } - else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, "/device:CPU:0")] - for worker in self._workers - } + num_gpus = num_gpus_per_worker + extended = MirroredExtended(self, devices, num_gpus, + cross_device_ops or cross_tower_ops, + auto_shard_dataset) + super(MirroredStrategy, self).__init__(extended) - devices = nest.flatten(self._worker_device_map) - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode. - self._default_device = self._workers[0] +class MirroredExtended(CoreMirroredExtended): + """Implementation of (contrib) MirroredStrategy.""" - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? 
- self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice( - {d: i for i, d in enumerate(devices)}) + def __init__(self, + container_strategy, + devices=None, + num_gpus_per_worker=None, + cross_device_ops=None, + auto_shard_dataset=False): + super(MirroredExtended, self).__init__( + container_strategy, devices, num_gpus_per_worker, cross_device_ops) + self._auto_shard_dataset = auto_shard_dataset - def _create_variable(self, next_creator, *args, **kwargs): - """Create a mirrored variable. See `DistributionStrategy.scope`.""" - colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch. - def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on replicas with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # Initialize replicas with the same value: - def initial_value_fn(device=d): - if context.executing_eagerly(): - init_value = index[devices[0]].value() - return array_ops.identity(init_value) - else: - with ops.device(device): - init_value = index[devices[0]].initial_value - return array_ops.identity(init_value) - kwargs["initial_value"] = initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - # Don't record operations (e.g. other variable reads) during - # variable creation. 
- with tape.stop_recording(): - v = next_creator(*args, **kwargs) - assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index + This implementation is different than the one in + `tf.distribute.MirroredStrategy` for purposes of backward compatibility. + We treat the incoming dataset's batch size as per replica batch size. - return _create_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) + Args: + dataset: `tf.data.Dataset` for input. + Returns: + An `InputIterator` which returns inputs for each step of the computation. + """ + if self._cluster_spec: + worker_device_pairs = self._worker_devices + else: + worker_device_pairs = [("/job:localhost", self._devices)] + return values.DatasetIterator(dataset, worker_device_pairs) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): if self._cluster_spec: return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device, self._auto_shard_dataset) + functools.partial(self._call_dataset_fn, dataset_fn), + self._worker_devices, + auto_shard=self._auto_shard_dataset) else: - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device) - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. - def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = values.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) - for (name, output) in ctx.last_step_outputs.items(): - # Convert all outputs to tensors, potentially from `DistributedValues`. 
- ctx.last_step_outputs[name] = self.unwrap(output) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. - last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, wrap them in a Mirrored - # container, else in a PerDevice container. 
- if aggregation is variables_lib.VariableAggregation.NONE: - last_step_tensor_outputs_dict[name] = values.regroup( - {d: t for d, t in zip(self._devices, output)}, values.PerDevice) - else: - assert len(output) == 1 - last_step_tensor_outputs_dict[name] = output[0] - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _broadcast(self, tensor, destinations): - # TODO(josh11b): In eager mode, use one thread per device, or async mode. - return self._get_cross_tower_ops().broadcast(tensor, destinations or - self._devices) - - def _call_for_each_replica(self, fn, *args, **kwargs): - return _call_for_each_replica(self, fn, *args, **kwargs) - - def map(self, map_over, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - index = {} - for i, m in enumerate(map_over): - d = self._devices[i % len(self._devices)] - with ops.device(d): - l = index.get(d, []) - l.append(fn(m, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs))) - index[d] = l - # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput - # in addition to PerDevice data. - return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) - - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - del task_type, task_id - - if session_config: - session_config.isolate_session_state = True - - if cluster_spec: - self._initialize_multi_worker(self._num_gpus, cluster_spec) - - if self._cross_tower_ops is None: - if self._cluster_spec: - # It currently cannot detect the toplogy of remote workers. So we - # hard-code the multi-worker all-reduce algorithm for now. - if len(self._workers) == 1: - # The default is "nccl". - self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossDeviceOps() - else: - # The default is hierarchical reduce and broadcast. 
- self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( - self._workers, self._num_gpus) - else: - self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( - self._devices, session_config=session_config) - - def _get_cross_tower_ops(self): - if self._cross_tower_ops is None: - self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps()) - return self._cross_tower_ops - - def _reduce(self, aggregation, value, destinations): - assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerDevice or Mirrored - # values. For example, the same value could be present on all replicas in - # which case `value` would be a single value or value could be 0. - return _reduce_non_distributed_value(self, aggregation, value, - destinations) - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: - value = value.get(self._devices[0]) - if isinstance(value, (int, float)): - return value - return self.broadcast(value, destinations) - return self._get_cross_tower_ops().reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: - return [self.broadcast(v.get(self._devices[0]), d) - for v, d in value_destination_pairs] - return self._get_cross_tower_ops().batch_reduce(aggregation, - value_destination_pairs) - - def _update(self, var, options, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - assert isinstance(var, values.DistributedVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
- updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - assert isinstance(colocate_with, list) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - # TODO(josh11b): In eager mode, use one thread per device. - updates = {} - for d in colocate_with: - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - updates[d] = fn(*values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def read_var(self, replica_local_var): - """Read the aggregate value of a replica-local variable.""" - if isinstance(replica_local_var, values.ReplicaLocalVariable): - return replica_local_var._get_cross_replica() # pylint: disable=protected-access - assert isinstance(replica_local_var, values.Mirrored) - return array_ops.identity(replica_local_var.get()) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. 
- if set(val.devices) == self._canonical_device_set: - return [val.get(device=d) for d in self._devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] - - def value_container(self, val): - return values.value_container(val) - - @property - def num_replicas(self): - return len(self._devices) + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), self._devices) + # TODO(priyag): Delete this once all strategies use global batch size. @property - def num_replicas_in_sync(self): - return len(self._devices) - - def _worker_device_index(self): - return self._device_index - - @property - def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._devices) - - @property - def parameter_devices(self): - return list(self._devices) - - @property - def between_graph(self): + def _global_batch_size(self): return False - - @property - def should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - def non_slot_devices(self, var_list): - del var_list - return list(self._devices) - - def _get_devices_from(self, colocate_with=None): - if colocate_with is None: - return self._devices - else: - return cross_tower_ops_lib.get_devices_from(colocate_with) - - class _MirroredReplicaThread(threading.Thread): - """A thread that runs() a function on a device.""" - - def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, - **kwargs): - super(MirroredStrategy._MirroredReplicaThread, self).__init__() # pylint: disable=protected-access - self.coord = coord - self.distribution = dist - self.device = device - self.replica_id = dist.worker_devices.index(device) - self.variable_creator_fn = variable_creator_fn - # State needed to run and return the results of `fn`. 
- self.main_fn = fn - self.main_args = args - self.main_kwargs = kwargs - self.main_result = None - self.done = False - # State needed to run the next merge_call() (if any) requested via - # ReplicaContext. - self.merge_fn = None - self.merge_args = None - self.merge_kwargs = None - self.merge_result = None - self.captured_name_scope = None - # We use a thread.Event for the main thread to signal when this - # thread should start running (`should_run`), and another for - # this thread to transfer control back to the main thread - # (`has_paused`, either when it gets to a - # `get_replica_context().merge_call` or when `fn` returns). In - # either case the event starts cleared, is signaled by calling - # set(). The receiving thread waits for the signal by calling - # wait() and then immediately clearing the event using clear(). - self.should_run = threading.Event() - self.has_paused = threading.Event() - # These fields have to do with inheriting various contexts from the - # parent thread: - # pylint: disable=protected-access - self.context_mode = context.context()._eager_context.mode - if not context.context()._context_handle: - context.context()._initialize_handle_and_devices() - self.context_device_policy = ( - pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( - context.context()._context_handle)) - self.graph = ops.get_default_graph() - self._variable_creator_stack = self.graph._variable_creator_stack[:] - self._captured_var_scope = variable_scope.get_variable_scope() - # Adding a "/" at end lets us re-enter this scope later. 
- self._name_scope = self.graph.get_name_scope() - if self._name_scope: - self._name_scope += "/" - if self.replica_id > 0: - if not self._name_scope: - self._name_scope = "" - self._name_scope += "replica_%d/" % self.replica_id - - def run(self): - # pylint: disable=protected-access - self.graph._variable_creator_stack = self._variable_creator_stack - self.should_run.wait() - self.should_run.clear() - try: - if self.coord.should_stop(): - return - with self.coord.stop_on_exception(), \ - context.context()._mode(self.context_mode), \ - context.context().device_policy(self.context_device_policy), \ - _enter_graph(self.graph), \ - MirroredReplicaContext(self.distribution, self.replica_id), \ - ops.device(self.device), \ - ops.name_scope(self._name_scope), \ - variable_scope.variable_scope( - self._captured_var_scope, reuse=self.replica_id > 0), \ - variable_scope.variable_creator_scope(self.variable_creator_fn): - self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) - self.done = True - finally: - self.has_paused.set() - - -class MirroredReplicaContext(distribute_lib.ReplicaContext): - """ReplicaContext used in MirroredStrategy.call_for_each_replica(). - - Opened in `_MirroredReplicaThread`, to allow the user to invoke - `MirroredStrategy`'s specific implementation of `merge_call()`, - which works by delegating the function and its arguments to - the main thread (the one that invoked - `MirroredStrategy.call_for_each_replica()`). - """ - - def _merge_call(self, fn, *args, **kwargs): - """Delegate to the main thread to actually perform merge_call().""" - t = threading.current_thread() # a _MirroredReplicaThread - t.merge_fn = fn - t.merge_args = args - t.merge_kwargs = kwargs - t.captured_name_scope = t.graph.get_name_scope() - # Adding a "/" at end lets us re-enter this scope later. 
- if t.captured_name_scope: - t.captured_name_scope += "/" - t.has_paused.set() - t.should_run.wait() - t.should_run.clear() - if t.coord.should_stop(): - raise _RequestedStop() - return t.merge_result - - @property - def device(self): - distribute_lib.require_replica_context(self) - return self._distribution_strategy.worker_devices[self._replica_id] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index b8e7edaaf82804e6741cc4c94c44ed77189d7ad9..1027da857d3042e5f3699bf9e373c4be4d3a754a 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,14 +20,16 @@ from __future__ import print_function import sys +from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.contrib.distribute.python import values -from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function @@ -35,7 +37,6 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training as keras_training from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core @@ -47,7 +48,7 @@ 
from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import distribution_strategy_context as ds_context from tensorflow.python.training import gradient_descent from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib @@ -56,251 +57,240 @@ from tensorflow.python.training import server_lib GPU_TEST = "test_gpu" in sys.argv[0] -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=["graph", "eager"])) +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): - def _get_distribution_strategy(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - if context.num_gpus() > 1: - devices = ["/device:GPU:0", "/device:GPU:1"] - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - return mirrored_strategy.MirroredStrategy(devices) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - soft_placement = not GPU_TEST - print("testMinimizeLossGraph soft_placement:", soft_placement) 
- self._test_minimize_loss_graph( - self._get_distribution_strategy(), soft_placement=soft_placement) - - def testMapReduce(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_device_index(self._get_distribution_strategy()) - - def testReplicaId(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_replica_id(self._get_distribution_strategy()) - - def testNumReplicas(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy().num_replicas) - - def testNumReplicasInSync(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy(). - num_replicas_in_sync) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testRunRegroupError(self): - - def run_fn(device_id): - # Generates a list with different lengths on different devices. - # Will fail in _regroup() (if more than one device). - return list(range(device_id)) + def testNumReplicasInSync(self, distribution): + self.assertEqual(2, distribution.num_replicas_in_sync) - dist = self._get_distribution_strategy() - with dist.scope(), self.assertRaises(AssertionError): - dist.call_for_each_replica(run_fn, dist.worker_device_index) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) - @test_util.run_in_graph_and_eager_modes - def testReduceToCpu(self): - if not GPU_TEST: - self.skipTest("Not GPU test") + def testRunRegroupError(self, distribution): + def run_fn(): + replica_id = int(self.evaluate(_replica_id())) + # Generates a list with different lengths on different devices. 
+ # Will fail in _regroup() (if more than one device). + return list(range(replica_id)) - def run_fn(device_id): - return device_id + with distribution.scope(), self.assertRaises(AssertionError): + distribution.extended.call_for_each_replica(run_fn) - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_replica(run_fn, dist.worker_device_index) - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, + def testReduceToCpu(self, distribution): + with distribution.scope(): + result = distribution.extended.call_for_each_replica(_replica_id) + reduced = distribution.reduce( + reduce_util.ReduceOp.SUM, result, destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) + unwrapped = distribution.unwrap(reduced) self.assertEqual(1, len(unwrapped)) - expected = sum(range(len(dist.worker_devices))) + expected = sum(range(distribution.num_replicas_in_sync)) self.assertEqual(expected, self.evaluate(unwrapped[0])) - @test_util.run_in_graph_and_eager_modes - def testReduceOnlyFirstReplicaUpdates(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - def run_fn(device_id): - return constant_op.constant(3 + 5 * device_id) - - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_replica(run_fn, dist.worker_device_index) - reduced = dist.reduce( - variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - self.assertEqual(3, self.evaluate(unwrapped[0])) - - @test_util.run_in_graph_and_eager_modes() - def testReduceToMultipleDestinations(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - devices = ["/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - reduced = dist.reduce( - 
variable_scope.VariableAggregation.SUM, + def testMakeInputFnIterator(self, distribution): + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values) + + def testGlobalStepUpdate(self, distribution): + self._test_global_step_update(distribution) + + +def one_device_combinations(): + return combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_cpu, + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_cpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph", "eager"]) + + +class MirroredOneDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored1CPU", + lambda: mirrored_strategy.MirroredStrategy(["/device:CPU:0"]), + required_gpus=1), + combinations.mirrored_strategy_with_one_gpu, + combinations.NamedDistribution( + "CoreMirrored1CPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]), + required_gpus=1), + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph", "eager"])) + def testReduceToMultipleDestinations(self, distribution): + with distribution.scope(): + reduced = distribution.extended.reduce_to( + reduce_util.ReduceOp.SUM, 1.0, destinations=["/device:CPU:0", "/device:GPU:0"]) - unwrapped = dist.unwrap(reduced) - self.assertEqual(2, len(unwrapped)) + unwrapped = distribution.unwrap(reduced) + self.assertLen(unwrapped, 2) self.assertEqual(1.0, self.evaluate(unwrapped[0])) + @combinations.generate(one_device_combinations()) + def 
testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) -class MirroredStrategyVariableCreationTest(test.TestCase): + @combinations.generate(one_device_combinations()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - config = config_pb2.ConfigProto() - config.allow_soft_placement = True + @combinations.generate(one_device_combinations()) + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") - @test_util.run_in_graph_and_eager_modes(config=config) - def testSingleVariable(self): - self._skip_eager_if_gpus_less_than(1) +class MirroredStrategyVariableCreatorStackTest( + test.TestCase, parameterized.TestCase): + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) + def testCreatorStacksAreThreadLocal(self, distribution): + def model_fn(): + replica_id_str = str(self.evaluate(_replica_id())) + + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + replica_id_str + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. 
+ del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + distribution.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = distribution.extended.call_for_each_replica(model_fn) + result = distribution.unwrap(result) + expected = ["main_thread:thread_0", "main_thread:thread_1"] + self.assertEqual(expected, result) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyVariableCreationTest(test.TestCase): + + def testSingleVariable(self, distribution): def model_fn(): # This variable should be created only once across the threads because of - # special variable_creator functions used by `dist.call_for_each_replica`. + # special variable_creator functions used by + # `distribution.extended.call_for_each_replica`. v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnnamedVariable(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("foo:0", result.name) + def testUnnamedVariable(self, distribution): def model_fn(): v = variable_scope.variable(1.0) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - 
["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # Default name of "Variable" will be used. - self.assertEquals("Variable:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariables(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("Variable:0", result.name) + def testMultipleVariables(self, distribution): def model_fn(): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals("foo" + str(i) + ":0", v.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariablesWithSameCanonicalName(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("foo" + str(i) + ":0", v.name) + def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): vs = [] vs.append(variable_scope.variable(1.0, name="foo/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - 
["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for v in result: self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals(4, len(result)) - self.assertEquals("foo/bar:0", result[0].name) - self.assertEquals("foo_1/bar:0", result[1].name) - self.assertEquals("foo_1/bar_1:0", result[2].name) - self.assertEquals("foo/bar_1:0", result[3].name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testVariableWithSameCanonicalNameAcrossThreads(self): - self._skip_eager_if_gpus_less_than(1) - - def model_fn(device_id): - v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v + self.assertEqual(4, len(result)) + self.assertEqual("foo/bar:0", result[0].name) + self.assertEqual("foo_1/bar:0", result[1].name) + self.assertEqual("foo_1/bar_1:0", result[2].name) + self.assertEqual("foo/bar_1:0", result[3].name) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) + def testVariableWithSameCanonicalNameAcrossThreads(self, distribution): + def model_fn(): + replica_id = self.evaluate(_replica_id()) + v = variable_scope.variable(1.0, name="foo_" + str(replica_id)) + ds_context.get_replica_context().merge_call(lambda _: _) + return v - with dist.scope(): - result = dist.call_for_each_replica( - model_fn, dist.worker_device_index, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # The resulting mirrored variable will use the name from the first device. 
- self.assertEquals("foo_0:0", result.name) + self.assertEqual("foo_0:0", result.name) - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithLayers(self): - self._skip_eager_if_gpus_less_than(1) + def testWithLayers(self, distribution): def model_fn(features): with variable_scope.variable_scope("common"): layer1 = core.Dense(1) @@ -308,17 +298,14 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - ds = dist.distribute_dataset( + ds = distribution.distribute_dataset( lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) if context.executing_eagerly(): iterator = ds.make_one_shot_iterator() @@ -328,27 +315,23 @@ class MirroredStrategyVariableCreationTest(test.TestCase): features = iterator.get_next() - with dist.scope(): - result = dist.call_for_each_replica( - model_fn, features, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica( + model_fn, args=(features,)) suffixes = ["", "_1", "_2"] for (kernel, bias), suffix in zip(result, suffixes): self.assertIsInstance(kernel, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name) self.assertIsInstance(bias, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + 
self.assertEqual("common/dense" + suffix + "/bias:0", bias.name) + def testWithVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.variable(1.0, name="var0", aggregation=None) with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -362,37 +345,31 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): v = variable_scope.variable(1.0, name="var-main0") - self.assertEquals("var-main0:0", v.name) + self.assertEqual("var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("var0:0", v0.name) + self.assertEqual("var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("common/var1:0", v1.name) + self.assertEqual("common/var1:0", v1.name) self.assertIsInstance(v2, values.ReplicaLocalVariable) - self.assertEquals("common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertEqual("common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - self.assertEquals("common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def 
testWithGetVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation) + def testWithGetVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -404,33 +381,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): with variable_scope.variable_scope("main"): v = variable_scope.get_variable("var-main0", [1]) - self.assertEquals("main/var-main0:0", v.name) + self.assertEqual("main/var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var0:0", v0.name) + self.assertEqual("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var1:0", v1.name) + self.assertEqual("main/common/var1:0", v1.name) self.assertIsInstance(v2, values.ReplicaLocalVariable) - self.assertEquals("main/common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, - v2.aggregation) + self.assertEqual("main/common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, + 
v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - self.assertEquals("main/common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testOnlyFirstReplicaUpdatesVariables(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("main/common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + def testOnlyFirstReplicaUpdatesVariables(self, distribution): def create_fn(): aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA v0 = variable_scope.variable( @@ -446,71 +418,73 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1 devices = ["/device:GPU:0", "/device:CPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - v0, v1 = dist.call_for_each_replica(create_fn, run_concurrently=False) + with distribution.scope(): + v0, v1 = distribution.extended.call_for_each_replica(create_fn) self.evaluate(v0.initializer) self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0))) self.evaluate(v1.initializer) self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1))) + + def replica_id_plus_one(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) # Update using the assign_add member function. 
- def update_member_fn(device_id): - update0 = v0.assign_add(5.0 * (device_id + 1)) - update1 = v1.assign_add(7.0 * (device_id + 1)) + def update_member_fn(): + update0 = v0.assign_add(5.0 * replica_id_plus_one()) + update1 = v1.assign_add(7.0 * replica_id_plus_one()) return update0, update1 - update0a, update1a = dist.call_for_each_replica( - update_member_fn, dist.worker_device_index, run_concurrently=False) + update0a, update1a = distribution.extended.call_for_each_replica( + update_member_fn) # Update "sync on read" variable. - self.evaluate(dist.group(update0a)) + self.evaluate(distribution.group(update0a)) self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) # Writes are not synchronized for "sync on read" variables, # so device[1] can end up with a different value. self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) # Always reads from device 0. - self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1a)) + self.evaluate(distribution.group(update1a)) self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) # Writes are synchronized for v1, only the argument to assign_add on # device[0] is used. self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0 + 7.0, self.evaluate( + distribution.extended.read_var(v1))) # Update using state_ops.assign_add global function. 
- def update_state_ops_fn(device_id): - update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1)) - update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1)) + def update_state_ops_fn(): + update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one()) + update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one()) return update0, update1 - update0b, update1b = dist.call_for_each_replica( - update_state_ops_fn, dist.worker_device_index, run_concurrently=False) - self.evaluate(dist.group(update0b)) + update0b, update1b = distribution.extended.call_for_each_replica( + update_state_ops_fn) + self.evaluate(distribution.group(update0b)) # Update "sync on read" variable. self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1b)) + self.evaluate(distribution.group(update1b)) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1))) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate( + distribution.extended.read_var(v1))) + + def testNoneSynchronizationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. 
Please change " @@ -519,12 +493,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): "v", [1], synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testNoneSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. Please change " @@ -534,23 +504,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): name="v", synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable synchronization mode: Invalid for " "variable: v"): variable_scope.variable(1.0, name="v", synchronization="Invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -559,12 +521,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, 
aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -574,55 +532,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testThreeDevices(self): - self._skip_eager_if_gpus_less_than(2) - - def model_fn(): - v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNonMatchingVariableCreation(self): - self._skip_eager_if_gpus_less_than(1) - + def testNonMatchingVariableCreation(self, distribution): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): names = values.DistributedValues({ "/device:CPU:0": "foo", "/device:GPU:0": "bar" }) with self.assertRaises(RuntimeError): - _ = dist.call_for_each_replica(model_fn, names, run_concurrently=False) - - 
@test_util.run_in_graph_and_eager_modes(config=config) - def testReplicaLocalVariable(self): - self._skip_eager_if_gpus_less_than(1) + _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) + def testReplicaLocalVariable(self, distribution): all_v_sum = {} all_v_mean = {} components_sum = {} components_mean = {} - def model_fn(device_id): + def model_fn(): + replica_id = self.evaluate(_replica_id()) v_sum = variable_scope.variable( 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -633,26 +564,22 @@ class MirroredStrategyVariableCreationTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) self.assertTrue(isinstance(v_mean, values.ReplicaLocalVariable)) - updates = [v_sum.assign_add(2.0 + device_id), - v_mean.assign(6.0 * device_id)] - all_v_sum[device_id] = v_sum - all_v_mean[device_id] = v_mean + updates = [v_sum.assign_add(2.0 + replica_id), + v_mean.assign(6.0 * replica_id)] + all_v_sum[replica_id] = v_sum + all_v_mean[replica_id] = v_mean c_sum = v_sum.get() c_mean = v_mean.get() - components_sum[device_id] = c_sum - components_mean[device_id] = c_mean + components_sum[replica_id] = c_sum + components_mean[replica_id] = c_mean self.assertIsNot(v_sum, c_sum) self.assertIsNot(v_mean, c_mean) return updates, v_sum, v_mean, c_sum, c_mean - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): # Create "sum" and "mean" versions of ReplicaLocalVariables. ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( - dist.call_for_each_replica( - model_fn, dist.worker_device_index, run_concurrently=False)) + distribution.extended.call_for_each_replica(model_fn)) # Should see the same wrapping instance in all replicas. 
self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) @@ -667,10 +594,10 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Apply updates self.evaluate(variables.global_variables_initializer()) - self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)]) expected_sum = 0.0 expected_mean = 0.0 - for i, d in enumerate(dist.worker_devices): + for i, d in enumerate(distribution.extended.worker_devices): # Should see different values on different devices. v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) @@ -680,69 +607,125 @@ class MirroredStrategyVariableCreationTest(test.TestCase): expected = i * 6.0 self.assertEqual(expected, v_mean_value) expected_mean += expected - expected_mean /= len(dist.worker_devices) + expected_mean /= len(distribution.extended.worker_devices) # Without get(device), should return the value you get by # applying the reduction across all replicas (whether you use # read_var(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate( + distribution.extended.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate( + distribution.extended.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + # TODO(priyag): Update this test to work in eager mode as well. 
+ def testDynamicRnnVariables(self, distribution): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + # Two variables are created by the RNN layer. + self.assertEqual(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = distribution.unwrap(v) + self.assertStartsWith(v1._op.name, "replica_1/") + + def testReplicaLocalVariableUpdate(self, distribution): + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + return v_sum + + def update(var, value): + return var.assign(value) + + with distribution.scope(): + ret_v_sum = distribution.extended.call_for_each_replica(model_fn) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values before running the update ops. + self.assertEqual(1.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + update_ops = distribution.extended.update( + ret_v_sum, update, args=(5.0,), group=False) + self.evaluate(update_ops) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values after running the update ops. 
+ self.assertEqual(5.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(10.0, self.evaluate(ret_v_sum)) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) +class MirroredStrategyNameScopeTest(test.TestCase): # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. - def testNameScope(self): + def testNameScope(self, distribution): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) b = constant_op.constant(1.0, name="b") return a, b - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) - self.assertEquals(2, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("main/foo/" + name + ":0", v0.name) - self.assertEquals("main/replica_1/foo/" + name + ":0", v1.name) + v0, v1 = distribution.unwrap(v) + self.assertEqual("main/foo/" + name + ":0", v0.name) + self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name) - def testWithDefaultName(self): + def testWithDefaultName(self, distribution): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) b = 
constant_op.constant(2.0, name="b") return a, b - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) - self.assertEquals(2, len(result)) + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("foo/" + name + ":0", v0.name) - self.assertEquals("replica_1/foo/" + name + ":0", v1.name) + v0, v1 = distribution.unwrap(v) + self.assertEqual("foo/" + name + ":0", v0.name) + self.assertEqual("replica_1/foo/" + name + ":0", v1.name) # variable_scope.variable() respects name scopes when creating # variables. On the other hand variable_scope.get_variable() ignores name # scopes when creating variables. We test both methods of creating variables # to make sure that we have the same variable names in both cases. 
- def testNameScopeWithVariable(self): + def testNameScopeWithVariable(self, distribution): def in_cross_replica(_): c = variable_scope.variable(1.0, name="c") return c @@ -750,32 +733,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribution_strategy_context.get_replica_context().merge_call( - in_cross_replica) + c = ds_context.get_replica_context().merge_call(in_cross_replica) return b, c - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): a = variable_scope.variable(1.0, name="a") - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = distribution.extended.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("main/a:0", a0.name) - self.assertEquals("main/a/replica_1:0", a1.name) - self.assertEquals("main/b:0", b0.name) - self.assertEquals("main/b/replica_1:0", b1.name) - self.assertEquals("main/foo/c:0", c0.name) - self.assertEquals("main/foo/c/replica_1:0", c1.name) - - def testNameScopeWithGetVariable(self): + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("main/a:0", a0.name) + self.assertEqual("main/a/replica_1:0", a1.name) + self.assertEqual("main/b:0", b0.name) + self.assertEqual("main/b/replica_1:0", b1.name) + self.assertEqual("main/foo/c:0", c0.name) + self.assertEqual("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self, distribution): def in_cross_replica(_): c = variable_scope.get_variable("c", [1]) return 
c @@ -783,118 +762,78 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribution_strategy_context.get_replica_context().merge_call( - in_cross_replica) + c = ds_context.get_replica_context().merge_call(in_cross_replica) return b, c - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): a = variable_scope.get_variable("a", [1]) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = distribution.extended.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("a:0", a0.name) - self.assertEquals("a/replica_1:0", a1.name) - self.assertEquals("b:0", b0.name) - self.assertEquals("b/replica_1:0", b1.name) - self.assertEquals("c:0", c0.name) - self.assertEquals("c/replica_1:0", c1.name) - - def testDynamicRnnVariables(self): + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("a:0", a0.name) + self.assertEqual("a/replica_1:0", a1.name) + self.assertEqual("b:0", b0.name) + self.assertEqual("b/replica_1:0", b1.name) + self.assertEqual("c:0", c0.name) + self.assertEqual("c/replica_1:0", c1.name) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2), + combinations.NamedDistribution( + "CoreMirrored3Devices", + # pylint: disable=g-long-lambda + 
lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2)], + mode=["graph", "eager"])) +class MirroredThreeDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testThreeDevices(self, distribution): def model_fn(): - inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) - cell_fw = rnn_cell_impl.LSTMCell(300) - cell_bw = rnn_cell_impl.LSTMCell(300) - (outputs, _) = rnn.bidirectional_dynamic_rnn( - cell_fw, - cell_bw, - inputs, - dtype=dtypes.float32) - return outputs - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) - # Two variables are created by the RNN layer. - self.assertEquals(2, len(result)) - for v in result: - self.assertIsInstance(v, values.DistributedValues) - _, v1 = dist.unwrap(v) - self.assertStartsWith(v1.name, "replica_1/") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testReplicaLocalVariableUpdate(self): - with context.graph_mode(): - - def model_fn(): - v_sum = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) - return v_sum - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]) - - def update(var, value): - return var.assign(value) - - with dist.scope(): - ret_v_sum = dist.call_for_each_replica(model_fn, run_concurrently=False) - update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) - - # Initialize variables. - self.evaluate(variables.global_variables_initializer()) - # Assert that the aggregated value of the replica local vars is the sum - # of the individual values before running the update ops. 
- self.assertEquals(1.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(2.0, self.evaluate(ret_v_sum)) + v = variable_scope.variable(1.0, name="foo") + ds_context.get_replica_context().merge_call(lambda _: _) + return v - # Apply updates. - self.evaluate(update_ops) - # Assert that the aggregated value of the replica local vars is the sum - # of the individual values after running the update ops. - self.assertEquals(5.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(10.0, self.evaluate(ret_v_sum)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEqual("foo:0", result.name) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredVariableUpdateTest(test.TestCase): # The following tests check assign, assign_add and assign_sub on Mirrored # variables in replica and cross replica context. - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithoutAggregationType(self): + def testAssignMirroredVarReplicaContextWithoutAggregationType(self, + distribution): # Test that we always have an aggregation type set on the mirrored variable # if we assign to it in replica mode. 
- self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -904,23 +843,19 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "You must specify an aggregation method to update a " "MirroredVariable in Replica Context."): - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithSum(self): - # Test that we don't reduce a non-per-device value with the "sum" + def testAssignMirroredVarReplicaContextWithSum(self, distribution): + # Test that we don't reduce a non-per-replica value with the "sum" # aggregation type. 
- self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -929,225 +864,184 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non-DistributedValues value 5.0 cannot be reduced " - "with the given aggregation VariableAggregation.SUM."): - self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) + "with the given reduce op ReduceOp.SUM."): + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) - self.assertEquals(6.0, mirrored_var_result) + self.assertEqual(6.0, mirrored_var_result) - @test_util.run_in_graph_and_eager_modes(config=config) - def 
testAssignMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) - self.assertEquals(0.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(0.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) 
self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(5.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) # read_value == True mirrored_var_result = self.evaluate( mirrored_var.assign_add(6.0, read_value=True)) - self.assertEquals(7.0, mirrored_var_result) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(7.0, mirrored_var_result) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) # read_value == False self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) - self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(9.0, 
self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) - self.assertEquals(1.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(1.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - 
- with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) - self.assertEquals(6.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(6.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarCrossDeviceContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(5.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) - self.assertEquals(3.0, mirrored_var_result) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(3.0, mirrored_var_result) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(3.0, 
self.evaluate(mirrored_var.get("/device:CPU:0"))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarReplicaContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_replica_context().replica_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_sub(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) - self.assertEquals(4.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarReplicaContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = 
distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) - self.assertEquals(4.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.0, self.evaluate(mirrored_var)) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def testAssignMirroredVarInitializer(self): + def testAssignMirroredVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. 
with context.graph_mode(): @@ -1155,17 +1049,14 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.assertFalse(self.evaluate(mirrored_var.is_initialized())) self.evaluate(mirrored_var.initializer) self.assertTrue(self.evaluate(mirrored_var.is_initialized())) - def testAssignReplicaLocalVarInitializer(self): + def testAssignReplicaLocalVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. with context.graph_mode(): @@ -1177,11 +1068,9 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica( + model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.assertFalse(self.evaluate(replica_local_var.is_initialized())) @@ -1189,17 +1078,14 @@ class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): self.assertTrue(self.evaluate(replica_local_var.is_initialized())) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class ReplicaLocalVariableAssignTest(test.TestCase): - config = config_pb2.ConfigProto() - 
config.allow_soft_placement = True - - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignReplicaLocalVarSumAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarSumAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1207,19 +1093,16 @@ class ReplicaLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.SUM) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn, - run_concurrently=False) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the SUM of each of # values on each of the replicas. - self.assertEqual(2.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(2.0, self.evaluate( + distribution.read_var(replica_local_var))) # Assigning 6.0 in cross replica context will assign a value of # 6.0/num_replicas to each replica. tlv_ops = replica_local_var.assign(6.0) @@ -1227,11 +1110,10 @@ class ReplicaLocalVariableAssignTest(test.TestCase): # On reading the replica local var we should get the assigned value back. # The value on all the replicas are added before being returned by # `read_var`. 
- self.assertEqual(6.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignReplicaLocalVarMeanAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1239,24 +1121,22 @@ class ReplicaLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn, - run_concurrently=False) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the MEAN of values # on all replicas which is the value assigned in replica context. - self.assertEqual(1.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(1.0, self.evaluate( + distribution.read_var(replica_local_var))) tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. 
- self.assertEqual(6.0, self.evaluate(dist.read_var(replica_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) class MockModel(object): @@ -1290,25 +1170,25 @@ class MiniModel(keras_training.Model): return self.fc(inputs) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredStrategyDefunTest(test.TestCase): - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") - - def _call_and_check(self, model_fn, inputs, expected_result, defuns, - two_variables=False): + def _call_and_check(self, distribution, model_fn, inputs, expected_result, + defuns, two_variables=False): cpu_dev = device_util.canonicalize("CPU:0") gpu_dev = device_util.canonicalize("GPU:0") devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): mock_model = MockModel(two_variables) self.evaluate(variables.global_variables_initializer()) - result = dist.call_for_each_replica(model_fn, mock_model, *inputs, - run_concurrently=False) + result = distribution.extended.call_for_each_replica( + model_fn, args=[mock_model] + inputs) for device in devices: device_result = values.select_device(device, result) device_expected_result = values.select_device(device, expected_result) @@ -1320,18 +1200,15 @@ class MirroredStrategyDefunTest(test.TestCase): # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. 
- per_device_graph_functions = dist.call_for_each_replica( - defun.get_concrete_function, - mock_model, *inputs, run_concurrently=False) + per_replica_graph_functions = ( + distribution.extended.call_for_each_replica( + defun.get_concrete_function, args=[mock_model] + inputs)) for device in devices: - graph_function = per_device_graph_functions.get(device=device) + graph_function = per_replica_graph_functions.get(device=device) self.assertEqual(set(mock_model.variables), set(graph_function.graph.variables)) - @test_util.run_in_graph_and_eager_modes() - def testVariableInDefun(self): - self._skip_eager_if_gpus_less_than(1) - + def testVariableInDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1339,12 +1216,9 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return times_two(mock_model) - self._call_and_check(model_fn, [], 2.5, [times_two]) - - @test_util.run_in_graph_and_eager_modes() - def testVariableInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 2.5, [times_two]) + def testVariableInNestedDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1356,12 +1230,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return two_x_plus_one(mock_model) - self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) - - @test_util.run_in_graph_and_eager_modes() - def testTwoVariablesInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 3.5, + [times_two, two_x_plus_one]) + def testTwoVariablesInNestedDefun(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1373,12 +1245,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return fn2(mock_model) - self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) - - 
@test_util.run_in_graph_and_eager_modes() - def testGradientTapeOverNestedDefuns(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2], + two_variables=True) + def testGradientTapeOverNestedDefuns(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1394,32 +1264,21 @@ class MirroredStrategyDefunTest(test.TestCase): [v.get() for v in mock_model.variables]) return grads - self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2], two_variables=True) - @test_util.run_in_graph_and_eager_modes() - def testPassPerDevice(self): - self._skip_eager_if_gpus_less_than(1) - + def testPassPerReplica(self, distribution): @function.defun def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) - self._call_and_check(fn1, [factors], expected_result, [fn1]) - - @test_util.run_in_graph_and_eager_modes() - def testTrain(self): - self._skip_eager_if_gpus_less_than(1) - - cpu_dev = device_util.canonicalize("CPU:0") - gpu_dev = device_util.canonicalize("GPU:0") - devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) + factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) + self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) - with dist.scope(): + def testTrain(self, distribution): + with distribution.scope(): mock_model = MiniModel() mock_model.call = function.defun(mock_model.call) @@ -1429,11 +1288,11 @@ class MirroredStrategyDefunTest(test.TestCase): gradients_fn = backprop.implicit_grad(loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = dist.call_for_each_replica( - gradients_fn, None, 
run_concurrently=False) + grads_and_vars = distribution.extended.call_for_each_replica( + gradients_fn, args=(None,)) optimizer = gradient_descent.GradientDescentOptimizer(0.25) - update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access + update_ops = optimizer._distributed_apply(distribution, grads_and_vars) # pylint: disable=protected-access if not context.executing_eagerly(): self.evaluate(variables.global_variables_initializer()) @@ -1445,30 +1304,73 @@ class MirroredStrategyDefunTest(test.TestCase): self.assertAllEqual([0.5], updated_var_values[1]) +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=context.num_gpus()), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=context.num_gpus()), + required_gpus=1) + ], + mode=["graph"])) class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, strategy_test_lib.DistributionTestBase): - def _get_distribution_strategy(self): + def _configure_distribution_strategy(self, distribution): cluster_spec = server_lib.ClusterSpec({ "worker": ["/job:worker/task:0", "/job:worker/task:1"] }) - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure(cluster_spec=cluster_spec) - return strategy + distribution.configure(cluster_spec=cluster_spec) - def test_num_replicas_in_sync(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - strategy = self._get_distribution_strategy() + def test_num_replicas_in_sync(self, distribution): + self._configure_distribution_strategy(distribution) # We calculate the total number of gpus across the workers(2) specified in # the cluster spec. 
- self.assertEqual(context.num_gpus() * 2, strategy.num_replicas_in_sync) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy(), - learning_rate=0.05) + self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync) + + def testMinimizeLossGraph(self, distribution): + self._configure_distribution_strategy(distribution) + self._test_minimize_loss_graph(distribution, learning_rate=0.05) + + def testDeviceScope(self, distribution): + """Test the device scope of multi-worker MirroredStrategy.""" + self._configure_distribution_strategy(distribution) + with distribution.scope(): + a = constant_op.constant(1.) + with ops.device("/cpu:0"): + b = constant_op.constant(1.) + self.assertEqual(a.device, "/job:worker/task:0") + self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") + + def testMakeInputFnIterator(self, distribution): + self._configure_distribution_strategy(distribution) + dataset_fn = lambda: dataset_ops.Dataset.range(100) + num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [[i+j for j in range(num_gpus)] * num_workers + for i in range(0, 100, num_gpus)] + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. 
+ input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess) class MultiWorkerMirroredStrategyTestWithChief( @@ -1488,6 +1390,19 @@ class MultiWorkerMirroredStrategyTestWithChief( strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) + def testMinimizeLossGraphCoreMirroredStrategy(self): + strategy = mirrored_strategy.CoreMirroredStrategy( + num_gpus_per_worker=context.num_gpus()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_minimize_loss_graph(strategy, learning_rate=0.05) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py deleted file mode 100644 index 2bfe0f3e7a66311c9b0673761b73382e477cb24b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for class MirroredStrategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context - - -class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) - - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) - - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - - def testReplicaId(self): - self._test_replica_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - -class VariableCreatorStackTest(test.TestCase): - - def testCreatorStacksAreThreadLocal(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - - def model_fn(device_id): - assert isinstance(device_id, int) - - def thread_creator_fn(next_creator, *args, **kwargs): - return 
next_creator(*args, **kwargs) + ":thread_" + str(device_id) - - with variable_scope.variable_creator_scope(thread_creator_fn): - # Create a variable in this scope. - v = variable_scope.variable(1.0) - - # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_replica_context().merge_call( - lambda _: _) - return v - - def main_thread_creator(next_creator, *args, **kwargs): - # We are not using the underlying next_creator for test purposes. - del next_creator, args, kwargs - return "main_thread" - - with context.graph_mode(), \ - dist.scope(), \ - variable_scope.variable_creator_scope(main_thread_creator): - result = dist.call_for_each_replica(model_fn, dist.worker_device_index) - result = dist.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] - self.assertEquals(expected, result) - - -class MultiWorkerMirroredStrategyTest(test.TestCase): - - def testDeviceScope(self): - """Test the device scope of multi-worker MirroredStrategy.""" - with context.graph_mode(): - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure( - cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]}) - with strategy.scope(): - a = constant_op.constant(1.) - with ops.device("/cpu:0"): - b = constant_op.constant(1.) 
- self.assertEqual(a.device, "/job:worker/task:0") - self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 815644421e36cc397d6faebf9abd9c54bab557de..c492d8bafc9024ed059f05b92e5466f3702726b9 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -32,7 +32,8 @@ from tensorflow.python.training import moving_averages all_combinations = combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu], + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], mode=["graph"]) @@ -93,7 +94,8 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var = variables.Variable([10.0, 11.0]) val = constant_op.constant([1.0, 2.0]) decay = 0.25 - # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. assign = moving_averages.assign_moving_average( var, val, decay, zero_debias=False) @@ -121,7 +123,8 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var = variables.Variable([0.0, 0.0]) val = array_ops.placeholder(dtypes.float32) decay = 0.25 - # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. 
assign = moving_averages.assign_moving_average(var, val, decay) variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 8bdf0012087a60cd7d4acfd4eaf0ee0742275655..421507232ac26915741d422d8a23008ddb7bf143 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,13 +20,12 @@ from __future__ import print_function import six -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -40,10 +39,16 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # doing something that won't work with other DistributionStrategy # implementations? 
- def __init__(self, device, prefetch_on_device=None): - super(OneDeviceStrategy, self).__init__() + def __init__(self, device): + super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) + + +class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of OneDeviceStrategy.""" + + def __init__(self, container_strategy, device): + super(OneDeviceExtended, self).__init__(container_strategy) self._device = device - self._prefetch_on_device = prefetch_on_device self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): @@ -61,18 +66,29 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.colocate_with(colocate_with): return next_creator(*args, **kwargs) - def distribute_dataset(self, dataset_fn): - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), [self._device], - self._prefetch_on_device) + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch.""" + return values.DatasetIterator(dataset, [("/job:localhost", [self._device])]) + + def _distribute_dataset(self, dataset_fn): + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), [self._device]) - def _broadcast(self, tensor, destinations): + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + return values.InputFunctionIterator( + input_fn, [("/job:localhost", [self._device])], + [distribute_lib.InputContext()]) + + def _broadcast_to(self, tensor, destinations): del destinations return tensor # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
- def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): + def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, + initial_loop_values=None): if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) @@ -84,7 +100,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): fn_inputs = iterator.get_next() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs @@ -117,42 +133,25 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access return ctx - def _call_for_each_replica(self, fn, *args, **kwargs): - # We don't run `fn` in multiple threads in OneDeviceStrategy. - kwargs.pop("run_concurrently", None) - with ops.device(self._device), _OneDeviceReplicaContext(self): + def _call_for_each_replica(self, fn, args, kwargs): + strategy = self._container_strategy() + with ops.device(self._device), _OneDeviceReplicaContext(strategy): return fn(*args, **kwargs) - def map(self, map_over, fn, *args, **kwargs): - with ops.device(self._device): - return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - - def _reduce(self, aggregation, value, destinations): - del destinations - if not isinstance(value, values.MapOutput): - return value - l = value.get() - assert l - with ops.device(self._device): - if aggregation == vs.VariableAggregation.SUM: - return math_ops.add_n(l) - elif aggregation == vs.VariableAggregation.MEAN: - return math_ops.add_n(l) / len(l) - else: - assert False + def _reduce_to(self, reduce_op, value, destinations): + del reduce_op, destinations + return value - def _update(self, var, options, fn, *args, **kwargs): + def 
_update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. - return self._update_non_slot(var, options, fn, var, *args, **kwargs) + return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. with ops.device(self._device), distribute_lib.UpdateContext(self._device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -168,7 +167,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return value @property - def num_replicas(self): + def _num_replicas_in_sync(self): return 1 @property @@ -183,16 +182,33 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): del var_list return [self._device] - def _worker_device_index(self): - return 0 + @property + def experimental_should_init(self): + return True + + @property + def should_checkpoint(self): + return True + + @property + def should_save_summary(self): + return True + + # TODO(priyag): Delete this once all strategies use global batch size. 
+ @property + def _global_batch_size(self): + return True class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): + """ReplicaContext for OneDeviceStrategy.""" def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( - self, distribution_strategy, replica_id=0) + self, + distribution_strategy, + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property - def device(self): - return self._distribution_strategy.worker_devices[0] + def devices(self): + return [self._distribution_strategy.extended.worker_devices[0]] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 3fb92273924a665bf2a1ee5fc94b75273b8c5f78..d46cd6f529e363f76bfa2b22339add63530cfde8 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -35,12 +36,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - def testReplicaId(self): self._test_replica_id(self._get_distribution_strategy()) @@ -48,6 +43,20 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + @test_util.run_in_graph_and_eager_modes + def 
testMakeInputFnIterator(self): + d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = d.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, d.extended.worker_devices, expected_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 0554f4a83bda28142d709020a5a648127d66eab0..fa4705af7cb592119f56686d1f693a156f7b4b13 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def run_step(): return control_flow_ops.group(distribution.unwrap( distribution.call_for_each_replica( - model_fn, iterator.get_next(), run_concurrently=layer.built))) + model_fn, args=(iterator.get_next(),)))) if not context.executing_eagerly(): with self.cached_session() as sess: diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index bbfd94ed5c0dd5391db0f4e0043b66553b45270d..fc2d2b20c95f0260d8243b662a020ddee8a00b14 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values +from 
tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -64,7 +64,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): Operations that occur only on the first replica (such as incrementing the global step), will occur on the first replica *of every worker*. - It is expected to call `call_for_each_replica(fn, *args, **kwargs)` for any + It is expected to call `call_for_each_replica(fn, ...)` for any operations which potentially can be replicated across replicas (i.e. multiple GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra caution needs to be taken: @@ -94,13 +94,21 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): ValueError: if `cluster_spec` is given but `task_type` or `task_id` is not. """ - super(ParameterServerStrategy, self).__init__() + super(ParameterServerStrategy, self).__init__( + ParameterServerExtended(self, num_gpus_per_worker)) + + +class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of ParameterServerStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + super(ParameterServerExtended, self).__init__(container_strategy) self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local(num_gpus_per_worker) # We typically don't need to do all-reduce in this strategy. 
- self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossDeviceOps( + self._cross_device_ops = ( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_LOCAL_CPU)) def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, @@ -189,6 +197,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def _initialize_local(self, num_gpus_per_worker): """Initialize internal devices for local training.""" + self._worker_device = "/job:localhost" # Define compute devices which is a list of device strings and one for each # replica. When there are GPUs, replicate operations on these GPUs. # Otherwise, place operations on CPU. @@ -221,20 +230,51 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "ParameterServerStrategy with compute_devices = %r, " "variable_device = %r", self._compute_devices, self._variable_device) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" - return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._compute_devices, True) - def _broadcast(self, tensor, destinations): - if not cross_tower_ops_lib.check_destinations(destinations): + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + num_input_pipelines = multi_worker_util.worker_count( + self._cluster_spec, self._task_type) + else: + input_pipeline_id = 0 + num_input_pipelines = 1 + input_context = distribute_lib.InputContext( + num_input_pipelines=num_input_pipelines, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return 
values.InputFunctionIterator( + input_fn, worker_device_pairs, [input_context]) + + def _broadcast_to(self, tensor, destinations): + # This is both a fast path for Python constants, and a way to delay + # converting Python values to a tensor until we know what type it + # should be converted to. Otherwise we have trouble with: + # global_step.assign_add(1) + # since the `1` gets broadcast as an int32 but global_step is int64. + if isinstance(tensor, (float, int)): + return tensor + if not cross_device_ops_lib.check_destinations(destinations): destinations = self._compute_devices - return self._cross_tower_ops.broadcast(tensor, destinations) + return self._cross_device_ops.broadcast(tensor, destinations) + + def _allow_variable_partition(self): + return not context.executing_eagerly() # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through # this creator, such as "MutableHashTable". def _create_variable(self, next_creator, *args, **kwargs): - if self.num_replicas > 1: + if self._num_replicas_in_sync > 1: aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in ( vs.VariableAggregation.NONE, @@ -288,39 +328,37 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): with ops.device(self._variable_device): return var_creator(*args, **kwargs) - def _call_for_each_replica(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica(self, fn, *args, **kwargs) + return mirrored_strategy._call_for_each_replica( + self._container_strategy(), fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): + if not self._cluster_spec: + return if destinations is None: return - for d in cross_tower_ops_lib.get_devices_from(destinations): + for d in cross_device_ops_lib.get_devices_from(destinations): d_spec = tf_device.DeviceSpec.from_string(d) if d_spec.job == self._task_type and 
d_spec.task != self._task_id: raise ValueError( "Cannot reduce to another worker: %r, current worker is %r" % (d, self._worker_device)) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return mirrored_strategy._reduce_non_distributed_value( - self, aggregation, value, destinations) - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return self.broadcast(value.get(self._compute_devices[0]), destinations) - return self._cross_tower_ops.reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return [self.broadcast(v.get(self._compute_devices[0]), d) - for v, d in value_destination_pairs] + self, reduce_op, value, destinations) + return self._cross_device_ops.reduce( + reduce_op, value, destinations=destinations) + + def _batch_reduce_to(self, reduce_op, value_destination_pairs): for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) - return self._cross_tower_ops.batch_reduce(aggregation, - value_destination_pairs) + return self._cross_device_ops.batch_reduce(reduce_op, + value_destination_pairs) def _select_single_value(self, structured): """Select any single values in `structured`.""" @@ -334,9 +372,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "You cannot update variable with a Mirrored object with multiple " "components %r when using ParameterServerStrategy. You must " "specify a single value or a Mirrored with a single value." 
% x) - elif isinstance(x, values.PerDevice): + elif isinstance(x, values.PerReplica): raise ValueError( - "You cannot update variable with a PerDevice object %r when using " + "You cannot update variable with a PerReplica object %r when using " "ParameterServerStrategy. You must specify a single value or a " "Mirrored with a single value" % x) else: @@ -344,30 +382,26 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): result = fn(var, *self._select_single_value(args), **self._select_single_value(kwargs)) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -393,11 +427,11 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # variables. 
return array_ops.identity(var) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the strategy class. The strategy object will be re-initialized if `cluster_spec` is given but @@ -445,11 +479,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) @property - def num_replicas(self): - return len(self._compute_devices) - - @property - def num_replicas_in_sync(self): + def _num_replicas_in_sync(self): return len(self._compute_devices) @property @@ -465,11 +495,12 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return min(var_list, key=lambda x: x.name) @property - def between_graph(self): + def experimental_between_graph(self): + # TODO(yuefengz): Should this return False in the local case? return True @property - def should_init(self): + def experimental_should_init(self): return self._is_chief @property @@ -479,3 +510,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): @property def should_save_summary(self): return self._is_chief + + # TODO(priyag): Delete this once all strategies use global batch size. 
+ @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index b8d5d0ecafce700d0e061b132607965a33ca9cb6..1ada6a6ba493563cd56342854f8d84a8ed5a7d40 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -25,23 +25,29 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import 
distribution_strategy_context as ds_context from tensorflow.python.training import training_util CHIEF = run_config.TaskType.CHIEF @@ -49,6 +55,13 @@ WORKER = run_config.TaskType.WORKER PS = run_config.TaskType.PS +def _get_replica_id_integer(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if isinstance(replica_id, ops.Tensor): + replica_id = tensor_util.constant_value(replica_id) + return replica_id + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -85,8 +98,7 @@ class ParameterServerStrategyTestBase( config=sess_config) as sess, \ d.scope(): - # Define a variable outside the call_for_each_replica scope. This is not - # recommended. + # Define a variable outside the call_for_each_replica scope. n = variable_scope.get_variable('n', initializer=10.0) self.assertEqual(n.device, '/job:ps/task:0') @@ -94,9 +106,8 @@ class ParameterServerStrategyTestBase( if num_gpus == 0: last_part_device = 'device:CPU:0' else: - last_part_device = ( - 'device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + last_part_device = ('device:GPU:%d' % replica_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -178,6 +189,75 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) + def _test_device_assignment_distributed_enable_partitioner( + self, task_type, task_id, num_gpus): + d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + num_shards = len(d.parameter_devices) + partitioner = partitioned_variables.fixed_size_partitioner(num_shards) + with ops.Graph().as_default(), \ + self.cached_session(target=self._default_target, + config=sess_config) as sess, \ + d.scope(): + + n = variable_scope.get_variable( + 'n', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + + for part_id, var in 
enumerate(n): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + + def model_fn(): + a = constant_op.constant([3.0, 5.0]) + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable( + 'x', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + x_add = x.assign_add(a, name='x_add') + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + for part_id, var in enumerate(x): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + self.assertEqual(var.device, x_add[part_id].device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x_add[0]): + y = variable_scope.get_variable( + 'y', + initializer=constant_op.constant([20.0, 10.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + y_add = y.assign_add( + [array_ops.identity(x_add[0]), + array_ops.identity(x_add[1])]) + + for part_id, var in enumerate(y): + self.assertEqual(var.device, '/job:ps/task:0') + self.assertEqual(y_add[part_id].device, var.device) + self.assertEqual(var.device, x_add[0].device) + + return x_add, y_add + + x, y = d.call_for_each_replica(model_fn) + + if context.num_gpus() >= 1: + variables.global_variables_initializer().run() + x_val, y_val = sess.run([x, y]) + if num_gpus < 1: + self.assertEqual(x_val, [13.0, 25.0]) + self.assertEqual(y_val, [33.0, 35.0]) + else: + x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] + y_expect = [ + 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus + ] + self.assertEqual(x_val, x_expect) + self.assertEqual(y_val, y_expect) + def _test_device_assignment_local(self, d, compute_device='CPU', @@ -192,18 +272,16 @@ class ParameterServerStrategyTestBase( if 'CPU' in compute_device: replica_compute_device = '/device:CPU:0' else: - 
replica_compute_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + replica_compute_device = ('/device:GPU:%d' % replica_id) replica_compute_device = device_util.canonicalize( replica_compute_device) if 'CPU' in variable_device: replica_variable_device = '/device:CPU:0' else: - replica_variable_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_replica_context().replica_id) + replica_id = _get_replica_id_integer() + replica_variable_device = ('/device:GPU:%d' % replica_id) replica_variable_device = device_util.canonicalize( replica_variable_device) @@ -285,9 +363,9 @@ class ParameterServerStrategyTestBase( def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - if hasattr(d, '_cluster_spec') and d._cluster_spec: - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if 'chief' in d._cluster_spec.as_dict(): + if d.extended._cluster_spec: + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if 'chief' in d.extended._cluster_spec.as_dict(): num_workers += 1 else: num_workers = 1 @@ -320,7 +398,7 @@ class ParameterServerStrategyTestBase( x, y, z, train_op = d.call_for_each_replica(model_fn) train_op = d.group(train_op) - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True if task_id == 0: @@ -345,20 +423,25 @@ class ParameterServerStrategyTestBase( self._finish_condition.release() x_val, y_val, z_val = sess.run([x, y, z]) - self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas) - self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync) self.assertEqual(z_val, 30.0 + 1.0 * num_workers) - return (x_val == 10.0 + 
1.0 * num_workers * d.num_replicas and - y_val == 20.0 + 1.0 * num_workers * d.num_replicas and + return (x_val == 10.0 + 1.0 * num_workers * d.num_replicas_in_sync and + y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - assert hasattr(d, '_cluster_spec') and d._cluster_spec - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if CHIEF in d._cluster_spec.as_dict(): - num_workers += 1 + if task_type: + # Multi-worker + assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d.extended._cluster_spec.as_dict(): + num_workers += 1 + else: + # local + num_workers = 1 with ops.Graph().as_default(), \ self.cached_session(target=master_target, @@ -389,7 +472,7 @@ class ParameterServerStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -399,7 +482,7 @@ class ParameterServerStrategyTestBase( with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -407,10 +490,12 @@ class ParameterServerStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True - if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id): + if (not task_type or + multi_worker_util.is_chief( + d.extended._cluster_spec, task_type, task_id)): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. @@ -433,8 +518,40 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_objects( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class ParameterServerStrategyTest(ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): @classmethod @@ -473,6 +590,12 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, def testDeviceAssignmentDistributed(self, num_gpus): self._test_device_assignment_distributed('worker', 1, num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + self._test_device_assignment_distributed_enable_partitioner( + 'worker', 1, num_gpus) + def testSimpleBetweenGraph(self): self._run_between_graph_clients(self._test_simple_increment, self._cluster_spec, context.num_gpus()) @@ -484,10 +607,55 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): + def testMinimizeLossGraphDistributed(self, num_gpus): self._run_between_graph_clients(self._test_minimize_loss_graph, self._cluster_spec, num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraphLocal(self, num_gpus): + self._test_minimize_loss_graph(None, None, num_gpus) + + # TODO(priyag): Refactor this and other multi worker tests. 
+ @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorLocal(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) # only one worker and pipeline for local. 
+ self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + + def testGlobalStepUpdate(self): + strategy = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=context.num_gpus()) + self._test_global_step_update(strategy) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -530,9 +698,9 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, v = variable_scope.get_variable('v', initializer=10.0) _ = v * v v, = tape.watched_variables() - w = distribution.value_container(v) + w = distribution.extended.value_container(v) self.assertIs(values.AggregatingVariable, type(w)) - distribution.call_for_each_replica(f) + distribution.extended.call_for_each_replica(f) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index a5adaac47ceb3e22909bb852c6e3418446710a51..c928b6d9f1f21508edd753f94c38ab2723cc0a9f 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -90,25 +90,21 @@ class StandardSingleLossStep(StandardInputStep): super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer - self._is_run_concurrently = False self._iterations_per_step = iterations_per_step def __call__(self): with self._distribution.scope(): - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): """Function to run one iteration with one input.""" gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) grads_and_vars = self.distribution.call_for_each_replica( - gradients_fn, - ctx, *inputs, - run_concurrently=self._is_run_concurrently) + gradients_fn, args=(ctx,) + inputs) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. 
# Otherwise, multiple sets of mirrored variables are going to be # created. - self._is_run_concurrently = True return self._optimizer._distributed_apply( # pylint: disable=protected-access self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 60ef0a2106a17d2eede6acec9b9178d1a9d736ff..5a8e8ed0dda0b99e759edbe916a46dab953929a0 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -19,16 +19,21 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import distribution_strategy_context as ds_context from tensorflow.python.training import optimizer @@ -45,8 +50,7 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_replica() call, calls a # get_replica_context().merge_call() that raises an exception. 
def _merge_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _raise_exception_fn) + ds_context.get_replica_context().merge_call(_raise_exception_fn) # Must be the argument to a get_replica_context().merge_call() call, calls @@ -59,8 +63,7 @@ def _call_raises_fn(dist): # calls a get_replica_context().merge_call() that calls a # call_for_each_replica() that raises an exception. def _merge_call_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _call_raises_fn) + ds_context.get_replica_context().merge_call(_call_raises_fn) # Must be the argument to a get_replica_context().merge_call() call, calls @@ -74,8 +77,7 @@ def _call_merge_raises_fn(dist): # get_replica_context().merge_call() that calls a call_for_each_replica() that # calls a get_replica_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribution_strategy_context.get_replica_context().merge_call( - _call_merge_raises_fn) + ds_context.get_replica_context().merge_call(_call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -104,7 +106,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one, run_concurrently=l.built) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. 
before_list = [] @@ -114,8 +116,7 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -160,7 +161,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] @@ -169,8 +170,7 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.reduce(reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -189,40 +189,20 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_map_reduce(self, d, in_graph=None): - with d.scope(): - map_in = [constant_op.constant(i) for i in range(10)] - map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out, - "/device:CPU:0") - expected = 90 # 2 * (0 + 1 + ... 
+ 9) - self.assertEqual(expected, observed.numpy()) - - def _test_device_index(self, d): - with d.scope(): - expected_devices = [False] * len(d.worker_devices) - - def mark_devices_fn(device_id): - self.assertLess(device_id, len(d.worker_devices)) - self.assertFalse(expected_devices[device_id]) - expected_devices[device_id] = True - - d.call_for_each_replica(mark_devices_fn, d.worker_device_index) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) - def _test_replica_id(self, d): with d.scope(): - expected_devices = [False] * len(d.worker_devices) + expected_devices = [False] * len(d.extended.worker_devices) def mark_devices_fn(): - replica_id = ( - distribution_strategy_context.get_replica_context().replica_id) - self.assertLess(replica_id, len(d.worker_devices)) + replica_id = self.evaluate( + ds_context.get_replica_context().replica_id_in_sync_group) + self.assertLess(replica_id, len(d.extended.worker_devices)) self.assertFalse(expected_devices[replica_id]) expected_devices[replica_id] = True d.call_for_each_replica(mark_devices_fn) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + self.assertAllEqual(expected_devices, + [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): @@ -234,3 +214,78 @@ class DistributionTestBase(test.TestCase): dist.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): dist.call_for_each_replica(_merge_call_merge_raises_fn) + + def _input_fn_to_test_input_context(self, + dataset_fn, + expected_num_replicas_in_sync, + expected_num_input_pipelines, + expected_input_pipeline_id): + # Use a list of one element as counter so that it can be captured by the + # `_input_fn`. This counter is incremented by 1 each time an input_fn is + # called. We use this counter to check whether the `input_pipeline_id` + # matches the counter in the in-graph replication. 
+ worker_id_counter = [0] + + def _input_fn(input_context): + """Input fn for testing.""" + self.assertIsNotNone(input_context) + self.assertEqual(expected_num_replicas_in_sync, + input_context.num_replicas_in_sync) + self.assertEqual(expected_num_input_pipelines, + input_context.num_input_pipelines) + if expected_input_pipeline_id is not None: + self.assertEqual(expected_input_pipeline_id, + input_context.input_pipeline_id) + else: + self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) + worker_id_counter[0] += 1 + + return dataset_fn() + + return _input_fn + + def _test_input_fn_iterator(self, iterator, devices, expected_values, + sess=None): + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + def _test_global_step_update(self, strategy): + with strategy.scope(): + global_step = variable_scope.get_variable( + "global_step", + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + train_op = global_step.assign_add(1) + value = global_step.read_value() + return train_op, value + + train_ops, value = strategy.call_for_each_replica(model_fn) + self.evaluate(strategy.group(train_ops)) + global_step_tensors = strategy.unwrap(value) + global_step_values = self.evaluate(global_step_tensors) + self.assertEqual([1] * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 65ef21df09ba34c274cdce73996bff7b9c32da85..f1115cb0c07666e9fe3a640cab6fb927d6d508c0 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -23,21 +23,23 @@ from __future__ import print_function import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute 
import values from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -130,8 +132,21 @@ class TPUStrategy(distribute_lib.DistributionStrategy): num_cores: Number of cores to use on the TPU. If None specified, then auto-detect the cores and topology of the TPU system. """ - super(TPUStrategy, self).__init__() + super(TPUStrategy, self).__init__(TPUExtended( + self, tpu_cluster_resolver, steps_per_run, num_cores)) + @property + def steps_per_run(self): + """DEPRECATED: use .extended.steps_per_run instead.""" + return self._extended.steps_per_run + + +class TPUExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of TPUStrategy.""" + + def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, + num_cores=None): + super(TPUExtended, self).__init__(container_strategy) self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) # TODO(sourabhbajaj): Change this from num_cores to metadata_override @@ -141,11 +156,11 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # parallelism. 
device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) if "device:TPU:" in d.name} - self._device_index = values.PerDevice(device_map) + self._device_index = values.PerReplica(device_map) self._host_device = self.get_host_cpu_device(0) self._tpu_devices = sorted(device_map.keys()) # Only create variables for the number of replicas we're running. - self._tpu_devices = self._tpu_devices[:self.num_replicas] + self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. @@ -214,20 +229,29 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return enqueue_op_per_host - def distribute_dataset(self, dataset_fn): - worker_map = { - self.get_host(hid): [self.get_host_cpu_device(hid)] + def _make_dataset_iterator(self, dataset): + """Make iterators for each of the TPU hosts.""" + + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) - } + ] + return values.DatasetIterator(dataset, worker_devices, + self._num_replicas_in_sync) + + def _distribute_dataset(self, dataset_fn): + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), worker_map) + functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. 
- def _run_steps_on_dataset(self, fn, multi_worker_iterator, iterations, - initial_loop_values=None): - + def _experimental_run_steps_on_iterator( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): output_shapes = multi_worker_iterator.output_shapes shapes = nest.flatten(output_shapes) if any([not s.is_fully_defined() for s in shapes]): @@ -257,7 +281,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): fn_inputs = dequeue_fn() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -279,7 +303,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): self._outer_control_flow_context = ( ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - replicate_inputs = [[]] * self.num_replicas + replicate_inputs = [[]] * self._num_replicas_in_sync replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) @@ -303,27 +327,27 @@ class TPUStrategy(distribute_lib.DistributionStrategy): last_step_tensor_outputs_dict = nest.pack_sequence_as( ctx.last_step_outputs, last_step_tensor_outputs) - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, take the first value + # For outputs that have already been reduced, take the first value # from the list as each value should be the same. Else return the full # list of values. - # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. 
- if aggregation is not variables_lib.VariableAggregation.NONE: + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica + # value. + if reduce_op is not None: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access return ctx - def _call_for_each_replica(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. - kwargs.pop("run_concurrently", None) - with _TPUReplicaContext(self): + with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def initialize(self): + def _initialize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") @@ -338,7 +362,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): tpu.initialize_system()) return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - def finalize(self): + def _finalize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. raise NotImplementedError("Eager mode not supported in TPUStrategy.") @@ -346,7 +370,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return [tpu.shutdown_system()] def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. + # TODO(jhseu): Change this when we support model parallelism. 
return self._tpu_devices def _create_variable(self, next_creator, *args, **kwargs): @@ -383,12 +407,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, **kwargs) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self.num_replicas) - elif aggregation != vs.VariableAggregation.SUM: + value *= (1. / self._num_replicas_in_sync) + elif reduce_op != reduce_util.ReduceOp.SUM: raise NotImplementedError( "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) @@ -396,27 +420,22 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. - devices = cross_tower_ops_lib.get_devices_from(destinations) + devices = cross_device_ops_lib.get_devices_from(destinations) if len(devices) == 1: assert device_util.canonicalize(devices[0]) == device_util.canonicalize( self._host_device) else: raise ValueError("Multiple devices are not supported for TPUStrategy") - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return value[0] output = math_ops.add_n(value) - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: return output * (1. / len(value)) return output - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): assert isinstance(var, values.TPUMirroredVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
- if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if should_group: + if group: return fn(var, *args, **kwargs) else: return [fn(var, *args, **kwargs)] @@ -431,9 +450,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - # TODO(josh11b): Need to implement _update_non_slot()! + return values.update_regroup(self, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -445,7 +462,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should - # be represented using a PerDevice wrapper instead of a list with + # be represented using a PerReplica wrapper instead of a list with # one entry per device. 
return val return [val] @@ -453,14 +470,10 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def value_container(self, value): return value - def _broadcast(self, tensor, destinations): + def _broadcast_to(self, tensor, destinations): del destinations return tensor - @property - def num_replicas(self): - return self._num_cores_override or self._tpu_metadata.num_cores - @property def num_hosts(self): return self._tpu_metadata.num_hosts @@ -470,15 +483,15 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return self._tpu_metadata.num_of_cores_per_host @property - def num_replicas_in_sync(self): - return self.num_replicas + def _num_replicas_in_sync(self): + return self._num_cores_override or self._tpu_metadata.num_cores @property - def between_graph(self): + def experimental_between_graph(self): return False @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -500,14 +513,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def non_slot_devices(self, var_list): return self._host_device - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
with ops.device(self._host_device), distribute_lib.UpdateContext( self._host_device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -521,11 +532,11 @@ class TPUStrategy(distribute_lib.DistributionStrategy): def get_host_cpu_device(self, host_id): return self.get_host(host_id) + "/device:CPU:0" - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): del cluster_spec, task_type, task_id if session_config: session_config.isolate_session_state = True @@ -533,6 +544,11 @@ class TPUStrategy(distribute_lib.DistributionStrategy): if cluster_spec: session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return True + class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" @@ -540,9 +556,14 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): # TODO(sourabhbajaj): Call for each tower should be updating this. 
def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( - self, distribution_strategy, replica_id=0) + self, + distribution_strategy, + # TODO(b/118385803): properly initialize replica_id, instead of always 0 + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property - def device(self): + def devices(self): distribute_lib.require_replica_context(self) - return self._distribution_strategy.worker_devices[self._replica_id] + ds = self._distribution_strategy + replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) + return [ds.extended.worker_devices[replica_id]] diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index d514e6f4c158d15665a2cd46be0547178da66544..855b9c29aec0c0a65f1a715eea764067a41ba2f3 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import os +from absl.testing import parameterized -from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn as model_fn_lib @@ -35,10 +35,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from 
tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -190,10 +192,10 @@ def _make_mirrored(): class RegroupAndSelectDeviceTest(test.TestCase): - def _is_per_device(self, result, expected, klass=values.PerDevice): + def _is_per_replica(self, result, expected, klass=values.PerReplica): self.assertIsInstance(result, klass) # We canonicalize the devices to match the device strings returned - # by PerDevice, which also does device string canonicalization. + # by PerReplica, which also does device string canonicalization. devices = [device_util.canonicalize(_device_str(i)) for i in range(len(expected))] self.assertEqual(set(devices), set(result.devices)) @@ -206,18 +208,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): _nested_value("2")}) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"]) - self._is_per_device(result[2], ["h1", "h2"]) + self._is_per_replica(result[0], ["a1", "a2"]) + self._is_per_replica(result[2], ["h1", "h2"]) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"]) - self._is_per_device(result[1][2], ["g1", "g2"]) + self._is_per_replica(result[1][0], ["b1", "b2"]) + self._is_per_replica(result[1][2], ["g1", "g2"]) self.assertIsInstance(result[1][1], dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"]) - self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) # Also test that we can undo the merge using 
select_device() self.assertEqual(_nested_value("1"), @@ -238,18 +240,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) - self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) - self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored) self.assertIsInstance(result[1][1], dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) - self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -275,7 +277,7 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): ("b", foo)}) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) - self._is_per_device(result[0], ["a", "b"]) + self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) # Test select_device(), should undo the merge done by regroup(). 
@@ -325,69 +327,46 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) - self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) for device_id in range(3): d = _device_str(device_id) - self.assertEquals(created_estimator_specs[device_id].loss, - merged_estimator_spec.loss.get(d)) - self.assertEquals(created_estimator_specs[device_id].train_op, - merged_estimator_spec.train_op.get(d)) + self.assertEqual(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEqual(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) # Scaffold is populated by `EstimatorSpec.__new__`. - self.assertEquals(created_estimator_specs[device_id].scaffold, - merged_estimator_spec.scaffold.get(d)) + self.assertEqual(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) # Also test that we can undo the merge using select_device() - self.assertEquals(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) + self.assertEqual(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) -class PerDeviceDatasetTest(test.TestCase): +class PerReplicaDatasetTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True - def _test_iterator_no_prefetch(self, devices, dataset, expected_values): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) + def _test_iterator(self, devices, dataset, expected_values): + per_replica_dataset = values.PerReplicaDataset(dataset, devices) if context.executing_eagerly(): - iterator = per_device_dataset.make_one_shot_iterator() + iterator = per_replica_dataset.make_one_shot_iterator() else: - iterator = 
per_device_dataset.make_initializable_iterator() + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate([iterator.initializer]) for expected_value in expected_values: next_element = iterator.get_next() - actual = self.evaluate([ - values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, actual) + computed_value = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() self.evaluate([ values.select_device(d, next_element) for d in devices]) - def _test_iterator_with_prefetch(self, devices, dataset, expected_values): - if not context.executing_eagerly(): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=True) - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) - - def _test_iterator(self, devices, dataset, expected_values): - self._test_iterator_no_prefetch(devices, dataset, expected_values) - self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] @@ -442,9 +421,8 @@ class PerDeviceDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices( random_ops.random_uniform((10,))) - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) - iterator = per_device_dataset.make_initializable_iterator() + per_replica_dataset = 
values.PerReplicaDataset(dataset, devices) + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate(iterator.initializer) next_element = iterator.get_next() @@ -463,7 +441,7 @@ class PerDeviceDatasetTest(test.TestCase): class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - def _test_iterator(self, iterator, devices, expected_values): + def _test_iterator(self, sess, iterator, devices, expected_values): next_element = iterator.get_next() for device in devices: v = values.select_device(device, next_element) @@ -472,73 +450,79 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): self.assertTrue(element.device in device) for expected_value in expected_values: - actual = self.evaluate( + actual = sess.run( [values.select_device(d, next_element) for d in devices]) self.assertEqual(expected_value, actual) with self.assertRaises(errors.OutOfRangeError): - self.evaluate([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_device(d, next_element) for d in devices]) - def _test_dataset(self, dataset_fn, worker_device_map, devices, - expected_values): + def _test_dataset(self, dataset_fn, worker_devices, devices, + expected_values, auto_shard=True): multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) - multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() - self._test_iterator(multi_worker_iterator, devices, expected_values) + dataset_fn, worker_devices, auto_shard=auto_shard) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.cached_session() as sess: + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, expected_values) def _cpu_devices(self): - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - 
["/job:worker/replica:0/task:1/device:CPU:0"])]) + worker_devices = [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] devices = [ "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def _cpu_and_one_gpu_devices(self): - # The worker_device_map doesn't have to be a OrderDict object, this is just - # to simplify the testing so that we can pass expected values as a list - # instead of a dict. - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", [ + worker_devices = [ + ("/job:worker/replica:0/task:0", [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0" - ]), ("/job:worker/replica:0/task:1", [ + ]), + ("/job:worker/replica:0/task:1", [ "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" - ])]) + ]) + ] devices = [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def testDataDistributionOneDevicePerWorker(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) + def testDataDistributionNoAutoShard(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_dataset(dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], + 
auto_shard=False) + def testDataDistributionTwoDevicePerWorker(self): - self.skipTest("Temporarily disabled.") if context.num_gpus() < 1: self.skipTest("A GPU is not available for this test.") - worker_device_map, devices = self._cpu_and_one_gpu_devices() + worker_devices, devices = self._cpu_and_one_gpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 2, 1, 3], [4, 6, 5, 7]]) def testTupleDataset(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): @@ -550,47 +534,221 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): expected_values = [ [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) ] - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, expected_values) def testInitializableIterator(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() - with context.graph_mode(): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) # After re-initializing the iterator, should be able to iterate again. 
- self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) def testValueErrorForIterator(self): - self.skipTest("Temporarily disabled.") # Incompatiable arguments. with self.assertRaises(ValueError): values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) # Test duplicated devices under same worker. - worker_device_map, _ = self._cpu_devices() - worker_device_map["/job:worker/replica:0/task:0"].append( - "/job:worker/replica:0/task:0/device:CPU:0") + worker_devices, _ = self._cpu_devices() + worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() with self.assertRaises(ValueError): multi_worker_iterator.get_next() -class MirroredVariableTest(test.TestCase): +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, + input_contexts) + else: + iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, + split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = 
evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) 
+ + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + 
"/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -602,9 +760,9 @@ class MirroredVariableTest(test.TestCase): v, _, mirrored = _make_mirrored() - self.assertEquals(v[0].name, mirrored.name) - self.assertEquals(v[0].dtype, mirrored.dtype) - self.assertEquals(v[0].shape, mirrored.shape) + self.assertEqual(v[0].name, mirrored.name) + self.assertEqual(v[0].dtype, mirrored.dtype) + self.assertEqual(v[0].shape, mirrored.shape) 
@test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -614,9 +772,9 @@ class MirroredVariableTest(test.TestCase): mirrored = values.MirroredVariable(index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, mirrored.name) - self.assertEquals(v.dtype, mirrored.dtype) - self.assertEquals(v.shape, mirrored.shape) + self.assertEqual(v.name, mirrored.name) + self.assertEqual(v.dtype, mirrored.dtype) + self.assertEqual(v.shape, mirrored.shape) def _assign_mirrored(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -736,14 +894,13 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testFetchAMirroredVariable(self): - if context.num_gpus() < 1 or context.executing_eagerly(): - self.skipTest("A GPU is not available for this test or it's eager mode.") - - with self.session( - graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( - ["/device:GPU:0"]).scope(): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph"])) + def testFetchAMirroredVariable(self, distribution): + with self.session(graph=ops.Graph()) as sess, distribution.scope(): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) @@ -769,7 +926,7 @@ def _make_replica_local(method): return v, replica_local -class ReplicaLocalVariableTest(test.TestCase): +class ReplicaLocalVariablePropertiesTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -778,15 +935,14 @@ class ReplicaLocalVariableTest(test.TestCase): def testProperties(self): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, replica_local = 
_make_replica_local( variable_scope.VariableAggregation.SUM) - self.assertEquals(v[0].name, replica_local.name) - self.assertEquals(v[0].dtype, replica_local.dtype) - self.assertEquals(v[0].shape, replica_local.shape) - self.assertEquals(variable_scope.VariableAggregation.SUM, - replica_local.aggregation) + self.assertEqual(v[0].name, replica_local.name) + self.assertEqual(v[0].dtype, replica_local.dtype) + self.assertEqual(v[0].shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.SUM, + replica_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -796,11 +952,32 @@ class ReplicaLocalVariableTest(test.TestCase): replica_local = values.ReplicaLocalVariable( index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, replica_local.name) - self.assertEquals(v.dtype, replica_local.dtype) - self.assertEquals(v.shape, replica_local.shape) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - replica_local.aggregation) + self.assertEqual(v.name, replica_local.name) + self.assertEqual(v.dtype, replica_local.dtype) + self.assertEqual(v.shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + replica_local.aggregation) + + def testTensorConversion(self): + with context.graph_mode(): + _, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) + converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) + # Resources variable are converted to tensors as well when as_ref is True. 
+ self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): def _assign_replica_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -817,22 +994,15 @@ class ReplicaLocalVariableTest(test.TestCase): save_path, _ = self._save_return_saver(sess, var) return save_path - def _dist_scope(self): - return mirrored_strategy.MirroredStrategy(_devices).scope() - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreReplicaLocalSumOneGraph(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - with self.cached_session(config=self.config) as sess: + def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): + with self.cached_session() as sess: v, replica_local = _make_replica_local( variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 7. 
save_path, saver = self._save_return_saver(sess, replica_local) @@ -844,19 +1014,18 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreReplicaLocalMeanOneGraph(self): + def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.cached_session(config=self.config) as sess: + with self.cached_session() as sess: v, replica_local = _make_replica_local( variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5. save_path, saver = self._save_return_saver(sess, replica_local) @@ -867,7 +1036,7 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _save_replica_local_mean(self): + def _save_replica_local_mean(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -876,7 +1045,7 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. 
self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5 save_path = self._save(sess, replica_local) @@ -884,7 +1053,7 @@ class ReplicaLocalVariableTest(test.TestCase): self._assign_replica_local(_devices, v, [5., 6.]) return save_path - def _save_replica_local_sum(self): + def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local("sum") @@ -892,7 +1061,7 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [1.5, 2.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 3.5 save_path = self._save(sess, replica_local) @@ -930,7 +1099,7 @@ class ReplicaLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual(3.5, self.evaluate(var)) - def _restore_replica_local_mean(self, save_path): + def _restore_replica_local_mean(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -939,13 +1108,13 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. 
saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _restore_replica_local_sum(self, save_path): + def _restore_replica_local_sum(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: v, replica_local = _make_replica_local( @@ -954,72 +1123,35 @@ class ReplicaLocalVariableTest(test.TestCase): # Overwrite the initial values. self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalRestoreReplicaLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution): + save_path = self._save_replica_local_mean(distribution) + self._restore_replica_local_mean(save_path, distribution) - save_path = self._save_replica_local_mean() - self._restore_replica_local_mean(save_path) + def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution): + save_path = self._save_replica_local_sum(distribution) + self._restore_replica_local_sum(save_path, distribution) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalRestoreReplicaLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_sum() - self._restore_replica_local_sum(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalMeanRestoreNormal(self): - if context.num_gpus() < 1 and 
context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_mean() + def testSaveReplicaLocalMeanRestoreNormal(self, distribution): + save_path = self._save_replica_local_mean(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveReplicaLocalSumRestoreNormal(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_replica_local_sum() + def testSaveReplicaLocalSumRestoreNormal(self, distribution): + save_path = self._save_replica_local_sum(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreReplicaLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - + def testSaveNormalRestoreReplicaLocalMean(self, distribution): save_path = self._save_normal() - self._restore_replica_local_mean(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreReplicaLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + self._restore_replica_local_mean(save_path, distribution) + def testSaveNormalRestoreReplicaLocalSum(self, distribution): save_path = self._save_normal() - self._restore_replica_local_sum(save_path) - - def testTensorConversion(self): - with context.graph_mode(): - _, replica_local = _make_replica_local( - variable_scope.VariableAggregation.SUM) - converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, replica_local.dtype) - - converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) - # Resources variable are 
converted to tensors as well when as_ref is True. - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, replica_local.dtype) + self._restore_replica_local_sum(save_path, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py index 5d57d144c1c16a08280970ecd89eb54f7cf1ffd4..b0bcf9b17456c938204a4892451928daf90b6743 100644 --- a/tensorflow/contrib/distribute/python/warm_starting_util_test.py +++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py @@ -44,7 +44,9 @@ class WarmStartingUtilWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], save_with_distribution=[True, False], restore_with_distribution=[True, False], mode=["graph"])) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 60f6b90edcb71f04bca29b90744db201e83cd545..3079175015a9aee1625404902070df8f13b2089c 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -72,7 +72,6 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", - "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -80,6 +79,7 @@ py_library( "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", + "//tensorflow/python/ops/signal", "//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py 
b/tensorflow/contrib/distributions/python/ops/sample_stats.py index aa680a92be64cf0f099acd335369f2a1610c5953..978e627d6638ddeea9df288d389354f0ac53d115 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -29,8 +29,8 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import spectral_ops from tensorflow.python.ops.distributions import util +from tensorflow.python.ops.signal import fft_ops __all__ = [ "auto_correlation", @@ -157,11 +157,11 @@ def auto_correlation( dtype.real_dtype.as_numpy_dtype(0.)) # Autocorrelation is IFFT of power-spectral density (up to some scaling). - fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + fft_x_rotated_pad = fft_ops.fft(x_rotated_pad) spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) # shifted_product is R[m] from above detailed explanation. # It is the inner product sum_n X[n] * Conj(X[n - m]). - shifted_product = spectral_ops.ifft(spectral_density) + shifted_product = fft_ops.ifft(spectral_density) # Cast back to real-valued if x was real to begin with. 
shifted_product = math_ops.cast(shifted_product, dtype) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 3aed121233be1268531495a2fa83fd323412e1fd..db77a39626900ec4d46263b1891e08c0262ce7da 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.experimental.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context from tensorflow.python.framework import ops @@ -52,11 +53,16 @@ class Iterator(iterator_ops.EagerIterator): TypeError: If `dataset` is an unsupported type. RuntimeError: When invoked without eager execution enabled. """ - if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access + # pylint: disable=protected-access + if (isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset) + or (isinstance(dataset, dataset_ops.DatasetV1Adapter) + and isinstance( + dataset._dataset, prefetching_ops._PrefetchToDeviceDataset))): raise TypeError( "`tf.data.experimental.prefetch_to_device()` is not compatible with " "`tf.contrib.eager.Iterator`. Use `for ... 
in dataset:` to iterate " "over the dataset instead.") + # pylint: enable=protected-access if not context.context().device_spec.device_type: is_remote_device = False diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 480777d948769b56ac1cc3be2052fe48459e98d6..66d52a74943d0d81fde05ce51b019558b327978d 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -768,7 +768,7 @@ }, "outputs": [], "source": [ - "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -781,7 +781,7 @@ }, "outputs": [], "source": [ - "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -794,7 +794,7 @@ }, "outputs": [], "source": [ - "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -808,7 +808,7 @@ "outputs": [], "source": [ "# wrong translation\n", - "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index f384d761a8430074f022c973d7ec3d46cd90f70b..3eb396a29ccdc0478384f9fa122465731740a30d 100644 --- 
a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,7 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output -from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 1ab5418fe4659cb0068ee8c3ca1442f6f723ee76..2f7cd131d3ed20df307ed231cce2ecb50ecfbceb 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,7 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config -from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index a926ffd5982116a21dc7a0fd1ff957d4ecc6bf94..1cd83bdb5de7c2f6dc91c980750b49aca1a7790b 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -14,6 +14,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":sequence_feature_column", + ":sequence_feature_column_v2", "//tensorflow/python:util", ], ) @@ -32,7 +33,7 @@ 
py_library( "//tensorflow/python:sparse_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", ], ) @@ -51,7 +52,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -69,7 +70,49 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras:layers", ], ) + +py_library( + name = "sequence_feature_column_v2", + srcs = ["python/feature_column/sequence_feature_column_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", + ], +) + +py_test( + name = "sequence_feature_column_v2_test", + srcs = ["python/feature_column/sequence_feature_column_v2_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + ":sequence_feature_column_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + 
"//tensorflow/python/feature_column:feature_column_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index dd6da35ed009c07ad3819e7860a283c7837c1f83..9b3a5c58aaa9498257fc971ac60b97f31d5185d8 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -222,10 +222,8 @@ def sequence_categorical_column_with_identity( ValueError: if `default_value` is not in range `[0, num_buckets)`. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_identity( - key=key, - num_buckets=num_buckets, - default_value=default_value)) + fc._categorical_column_with_identity( + key=key, num_buckets=num_buckets, default_value=default_value)) def sequence_categorical_column_with_hash_bucket( @@ -265,10 +263,8 @@ def sequence_categorical_column_with_hash_bucket( ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_hash_bucket( - key=key, - hash_bucket_size=hash_bucket_size, - dtype=dtype)) + fc._categorical_column_with_hash_bucket( + key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) def sequence_categorical_column_with_vocabulary_file( @@ -324,7 +320,7 @@ def sequence_categorical_column_with_vocabulary_file( ValueError: `dtype` is neither string nor integer. """ return fc._SequenceCategoricalColumn( - fc.categorical_column_with_vocabulary_file( + fc._categorical_column_with_vocabulary_file( key=key, vocabulary_file=vocabulary_file, vocabulary_size=vocabulary_size, @@ -384,7 +380,7 @@ def sequence_categorical_column_with_vocabulary_list( ValueError: if `dtype` is not integer or string. 
""" return fc._SequenceCategoricalColumn( - fc.categorical_column_with_vocabulary_list( + fc._categorical_column_with_vocabulary_list( key=key, vocabulary_list=vocabulary_list, dtype=dtype, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py index d8ca363627eace15e039679545366648df174c33..bcc25b8de895a769f9e11b207c2092e23d029b1f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py @@ -53,19 +53,20 @@ class SequenceFeatureColumnIntegrationTest(test.TestCase): return example def _build_feature_columns(self): - col = fc.categorical_column_with_identity( - 'int_ctx', num_buckets=100) + col = fc._categorical_column_with_identity('int_ctx', num_buckets=100) ctx_cols = [ - fc.embedding_column(col, dimension=10), - fc.numeric_column('float_ctx')] + fc._embedding_column(col, dimension=10), + fc._numeric_column('float_ctx') + ] identity_col = sfc.sequence_categorical_column_with_identity( 'int_list', num_buckets=10) bucket_col = sfc.sequence_categorical_column_with_hash_bucket( 'bytes_list', hash_bucket_size=100) seq_cols = [ - fc.embedding_column(identity_col, dimension=10), - fc.embedding_column(bucket_col, dimension=20)] + fc._embedding_column(identity_col, dimension=10), + fc._embedding_column(bucket_col, dimension=20) + ] return ctx_cols, seq_cols @@ -148,8 +149,8 @@ class SequenceExampleParsingTest(test.TestCase): """ example = _make_sequence_example() columns = [ - fc.categorical_column_with_identity('int_ctx', num_buckets=100), - fc.numeric_column('float_ctx'), + fc._categorical_column_with_identity('int_ctx', num_buckets=100), + fc._numeric_column('float_ctx'), col_fn(col_name, col_arg) ] context, seq_features = 
parsing_ops.parse_single_sequence_example( diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 2163af0b43864c96483df529f07881f2f985a80e..d5f74028298ee7015f5b2e3aaee7d9330c1acac1 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column import feature_column_lib as fc_lib from tensorflow.python.feature_column.feature_column import _LazyBuilder from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -109,13 +110,15 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=embedding_dimension_a, + embedding_column_a = fc._embedding_column( + categorical_column_a, + dimension=embedding_dimension_a, initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_b = fc.embedding_column( - categorical_column_b, dimension=embedding_dimension_b, + embedding_column_b = fc._embedding_column( + categorical_column_b, + dimension=embedding_dimension_b, initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) input_layer, sequence_length = sfc.sequence_input_layer( @@ -148,10 +151,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), 
dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) + embedding_column_a = fc._embedding_column(categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, @@ -206,7 +208,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) # Test that columns are reordered alphabetically. - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_b, categorical_column_a], dimension=embedding_dimension, initializer=_get_initializer(embedding_dimension, embedding_values)) @@ -244,11 +246,11 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc.categorical_column_with_identity( + categorical_column_b = fc._categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) with self.assertRaisesRegexp( @@ -315,10 +317,10 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size_a) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', 
num_buckets=vocabulary_size_b) - indicator_column_b = fc.indicator_column(categorical_column_b) + indicator_column_b = fc._indicator_column(categorical_column_b) input_layer, sequence_length = sfc.sequence_input_layer( features={ 'aaa': sparse_input_a, @@ -342,9 +344,9 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): values=(2, 0, 1), dense_shape=(2, 2)) - categorical_column_a = fc.categorical_column_with_identity( + categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -530,7 +532,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=3) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) input_layer, _ = sfc.sequence_input_layer( features={'aaa': sparse_input}, feature_columns=[indicator_column]) @@ -616,8 +618,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) + embedding_column_a = fc._embedding_column(categorical_column_a, dimension=2) with self.assertRaisesRegexp( ValueError, @@ -639,7 +640,7 @@ class InputLayerTest(test.TestCase): categorical_column_a = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) + indicator_column_a = fc._indicator_column(categorical_column_a) with self.assertRaisesRegexp( ValueError, @@ -918,8 +919,9 @@ class SequenceEmbeddingColumnTest( categorical_column = 
sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=embedding_dimension, + embedding_column = fc._embedding_column( + categorical_column, + dimension=embedding_dimension, initializer=_initializer) embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( @@ -956,8 +958,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=2) + embedding_column = fc._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -984,8 +985,7 @@ class SequenceEmbeddingColumnTest( categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, dimension=2) + embedding_column = fc._embedding_column(categorical_column, dimension=2) _, sequence_length = embedding_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) @@ -1055,7 +1055,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=embedding_dimension, initializer=_initializer) @@ -1101,7 +1101,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): expected_sequence_length_b = [2, 1] categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( 
[categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1152,7 +1152,7 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase): categorical_column_b = sfc.sequence_categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - shared_embedding_columns = fc.shared_embedding_columns( + shared_embedding_columns = fc_lib.shared_embedding_columns( [categorical_column_a, categorical_column_b], dimension=2) sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( @@ -1218,7 +1218,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1250,7 +1250,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': inputs})) @@ -1277,7 +1277,7 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): categorical_column = sfc.sequence_categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - indicator_column = fc.indicator_column(categorical_column) + indicator_column = fc._indicator_column(categorical_column) _, sequence_length = indicator_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py 
b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..0d34ad161855476b6a4cd9a258521dbe122b4140 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -0,0 +1,558 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This API defines FeatureColumn for sequential input. + +NOTE: This API is a work in progress and will likely be changing frequently. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import collections + + +from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.feature_column import feature_column_lib as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# pylint: disable=protected-access + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True): + """Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. 
+ + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + features: A dict mapping keys to tensors. + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + weight_collections: A list of collection names to which the Variable will be + added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc_old._normalize_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, fc_old._SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'You can wrap a sequence_categorical_column with an embedding_column ' + 'or indicator_column. 
' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + None, default_name='sequence_input_layer', values=features.values()): + builder = fc_old._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + target_shape = [shape[0], shape[1], num_elements] + output_tensors.append( + array_ops.reshape(dense_tensor, shape=target_shape)) + sequence_lengths.append(sequence_length) + + fc_old._verify_static_batch_size_equality(output_tensors, ordered_columns) + fc_old._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + sequence_length = _assert_all_equal_and_return(sequence_lengths) + + return array_ops.concat(output_tensors, -1), sequence_length + + +def concatenate_context_input(context_input, sequence_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. 
+ """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + """Returns a feature column that represents sequences of integers. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [watches_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + num_buckets: Range of inputs. 
Namely, inputs are expected to be in the + range `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace out-of-range inputs. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. + """ + return fc_old._SequenceCategoricalColumn( + fc_old._categorical_column_with_identity( + key=key, num_buckets=num_buckets, default_value=default_value)) + + +def sequence_categorical_column_with_hash_bucket( + key, hash_bucket_size, dtype=dtypes.string): + """A sequence of categorical terms where ids are set by hashing. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + tokens = sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=1000) + tokens_embedding = embedding_column(tokens, dimension=10) + columns = [tokens_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + hash_bucket_size: An int > 1. The number of buckets. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. 
+ """ + return fc_old._SequenceCategoricalColumn( + fc_old._categorical_column_with_hash_bucket( + key=key, hash_bucket_size=hash_bucket_size, dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """A sequence of categorical terms where ids use a vocabulary file. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + states = sequence_categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + states_embedding = embedding_column(states, dimension=10) + columns = [states_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. If None, it is set to the length of `vocabulary_file`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. 
Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. + """ + return fc_old._SequenceCategoricalColumn( + fc_old._categorical_column_with_vocabulary_file( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=num_oov_buckets, + default_value=default_value, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): + """A sequence of categorical terms where ids use an in-memory list. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + colors = sequence_categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), + num_oov_buckets=2) + colors_embedding = embedding_column(colors, dimension=3) + columns = [colors_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. 
+ default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a + hash of the input value. A positive `num_oov_buckets` can not be specified + with `default_value`. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. + """ + return fc_old._SequenceCategoricalColumn( + fc_old._categorical_column_with_vocabulary_list( + key=key, + vocabulary_list=vocabulary_list, + dtype=dtype, + default_value=default_value, + num_oov_buckets=num_oov_buckets)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32, + normalizer_fn=None): + """Returns a feature column that represents sequences of numeric data. + + Example: + + ```python + temperature = sequence_numeric_column('temperature') + columns = [temperature] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + sequence_feature_layer = SequenceFeatureLayer(columns) + input_layer, sequence_length = sequence_feature_layer(features) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input features. + shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`, + each example must contain `2 * sequence_length` values. 
+ default_value: A single value compatible with `dtype` that is used for + padding the sparse data into a dense `Tensor`. + dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of TensorFlow transformations. + + Returns: + A `SequenceNumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int. + ValueError: if any dimension in shape is not a positive integer. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + shape = fc_old._check_shape(shape=shape, key=key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. 
Given: {}'.format(normalizer_fn)) + + return SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype, + normalizer_fn=normalizer_fn) + + +def _assert_all_equal_and_return(tensors, name=None): + """Asserts that all tensors are equal and returns the first one.""" + with ops.name_scope(name, 'assert_all_equal', values=tensors): + if len(tensors) == 1: + return tensors[0] + assert_equal_ops = [] + for t in tensors[1:]: + assert_equal_ops.append(check_ops.assert_equal(tensors[0], t)) + with ops.control_dependencies(assert_equal_ops): + return array_ops.identity(tensors[0]) + + +class SequenceNumericColumn( + fc.SequenceDenseColumn, + collections.namedtuple( + 'SequenceNumericColumn', + ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))): + """Represents sequences of numeric data.""" + + @property + def _is_v2_column(self): + return True + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def transform_feature(self, transformation_cache, state_manager): + """See `FeatureColumn` base class. + + In this case, we apply the `normalizer_fn` to the input tensor. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Normalized input tensor. + """ + input_tensor = transformation_cache.get(self.key, state_manager) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor + + @property + def variable_shape(self): + """Returns a `TensorShape` representing the shape of sequence input.""" + return tensor_shape.TensorShape(self.shape) + + def get_sequence_dense_tensor(self, transformation_cache, state_manager): + """Returns a `TensorSequenceLengthPair`. 
+ + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + """ + sp_tensor = transformation_cache.get(self, state_manager) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self.variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + + # Get the number of timesteps per example + # For the 2D case, the raw values are grouped according to num_elements; + # for the 3D case, the grouping happens in the third dimension, and + # sequence length is not affected. + num_elements = (self.variable_shape.num_elements() + if sp_tensor.shape.ndims == 2 else 1) + seq_length = fc_old._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=num_elements) + + return fc.SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=seq_length) + + # TODO(b/119409767): Implement parents, _{get,from}_config. 
+ @property + def parents(self): + """See `FeatureColumn` base class.""" + raise NotImplementedError() + + def _get_config(self): + """See `FeatureColumn` base class.""" + raise NotImplementedError() + + @classmethod + def _from_config(cls, config, custom_objects=None, columns_by_name=None): + """See `FeatureColumn` base class.""" + raise NotImplementedError() + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4398a142065de0be7bee57cd7e54670bbae12e --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2_test.py @@ -0,0 +1,1508 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for sequence_feature_column_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc_old +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc +from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.feature_column import feature_column_lib as fc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [2, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 
0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[2], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 2, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]], + # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_embedding_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): + + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) + vocabulary_size = 3 + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) 
# id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc_old._embedding_column( + categorical_column_a, + dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = fc_old._embedding_column( + categorical_column_b, + dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. 
+ feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc_old._categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc_old._embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. 
categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_shared_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + + def _get_initializer(embedding_dimension, embedding_values): + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 3., 4.], [0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 5., 6.], [3., 4., 1., 2.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + # Test that columns are reordered alphabetically. 
+ shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension, + initializer=_get_initializer(embedding_dimension, embedding_values)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=shared_embedding_columns) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_shared_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence shared embedding column.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc_old._categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_old._categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_shared_embedding\. 
categorical_column must ' + r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b + }, + feature_columns=shared_embedding_columns) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args_a': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'sparse_input_args_b': { + # example 0, ids [1] + # example 1, ids [1, 0] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 1, 0), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'sparse_input_args_a': { + # feature 0, ids [[2], [0, 1]] + # feature 1, ids [[0, 0], [1]] + 'indices': ( + (0, 0, 0), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 0, 0, 1), + 'dense_shape': (2, 2, 2)}, + 'sparse_input_args_b': { + # feature 0, ids [[1, 1], [1]] + # feature 1, ids [[1], [0]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (1, 1, 1, 1, 0), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -] + [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]], + # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -] + [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_indicator_column( + self, sparse_input_args_a, sparse_input_args_b, expected_input_layer, + expected_sequence_length): + sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a) + sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b) + + vocabulary_size_a = 3 + vocabulary_size_b = 2 + + categorical_column_a = 
sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size_a) + indicator_column_a = fc_old._indicator_column(categorical_column_a) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size_b) + indicator_column_b = fc_old._indicator_column(categorical_column_b) + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[indicator_column_b, indicator_column_a]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_indicator_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc_old._categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc_old._indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1] + # example 1, [10.] 
+ 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected_input_layer': [ + [[0.], [1.]], + [[10.], [0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected_input_layer': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]], + 'expected_sequence_length': [2, 2]}, + ) + def test_numeric_column( + self, sparse_input_args, expected_input_layer, expected_sequence_length): + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + + numeric_column = sfc_old.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.] + # example 1, [10., 11., 12., 13.] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. 
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_input_layer': [ + # The output of numeric_column._get_dense_tensor should be flattened. + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]]], + 'expected_sequence_length': [2, 1]}, + ) + def test_numeric_column_multi_dim( + self, sparse_input_args, expected_input_layer, expected_sequence_length): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + + numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_not_equal(self): + """Tests that an error is raised when sequence lengths are not equal.""" + # Input a with sequence_length = [2, 1] + sparse_input_a = sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + # Input b with sequence_length = [1, 1] + sparse_input_b = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0)), + values=(1., 10.), + dense_shape=(2, 2)) + numeric_column_a = sfc_old.sequence_numeric_column('aaa') + numeric_column_b = 
sfc_old.sequence_numeric_column('bbb') + + _, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=[numeric_column_a, numeric_column_b]) + + with monitored_session.MonitoredSession() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Condition x == y did not hold element-wise:\] ' + r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + sess.run(sequence_length) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_shape': [2, 2, 4]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]] + # example 1, [[10., 11., 12., 13.], []] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2), + (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 4)}, + 'expected_shape': [2, 2, 4]}, + ) + def test_static_shape_from_tensors_numeric( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc_old.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 
'sparse_input_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected_shape': [4, 2, 3]}, + {'testcase_name': '3D', + 'sparse_input_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected_shape': [4, 2, 3]} + ) + def test_static_shape_from_tensors_indicator( + self, sparse_input_args, expected_shape): + """Tests that we return a known static shape when we have one.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=3) + indicator_column = fc_old._indicator_column(categorical_column) + + input_layer, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, feature_columns=[indicator_column]) + shape = input_layer.get_shape() + self.assertEqual(shape, expected_shape) + + +class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): + """Tests the utility fn concatenate_context_input.""" + + def test_concatenate_context_input(self): + seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2)) + context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + input_layer = sfc.concatenate_context_input(context_input, seq_input) + + expected = np.array([ + [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]], + [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]] + ], dtype=np.float32) + with monitored_session.MonitoredSession() as sess: + output = sess.run(input_layer) + 
self.assertAllEqual(expected, output) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_3', + 'seq_input_arg': np.arange(100).reshape(10, 10)}, + {'testcase_name': 'rank_gt_3', + 'seq_input_arg': np.arange(100).reshape(5, 5, 2, 2)} + ) + def test_sequence_input_throws_error(self, seq_input_arg): + seq_input = ops.convert_to_tensor(seq_input_arg) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'): + sfc.concatenate_context_input(context_input, seq_input) + + @parameterized.named_parameters( + {'testcase_name': 'rank_lt_2', + 'context_input_arg': np.arange(100)}, + {'testcase_name': 'rank_gt_2', + 'context_input_arg': np.arange(100).reshape(5, 5, 4)} + ) + def test_context_input_throws_error(self, context_input_arg): + context_input = ops.convert_to_tensor(context_input_arg) + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + seq_input = math_ops.cast(seq_input, dtype=dtypes.float32) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_seq_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + context_input = math_ops.cast(context_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'sequence_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + def test_integer_context_input_throws_error(self): + seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4)) + context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10)) + seq_input = 
math_ops.cast(seq_input, dtype=dtypes.float32) + with self.assertRaisesRegexp( + TypeError, 'context_input must have dtype float32'): + sfc.concatenate_context_input(context_input, seq_input) + + +class InputLayerTest(test.TestCase): + """Tests input_layer with sequence feature columns.""" + + def test_embedding_column(self): + """Tests that error is raised for sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc_old._embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc_old.input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + """Tests that error is raised for sequence indicator column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc_old._indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. 
categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc_old.input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + _assert_sparse_tensor_indices_shape(test_case, expected, actual) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + +def _assert_sparse_tensor_indices_shape(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (1, 2, 0), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((1, 2, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': (6, 7, 8), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, expected, 
id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithHashBucketTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'stringer', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + # Ignored to avoid hash dependence in test. + 'values': np.array((0, 0, 0), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_hash_bucket( + 'aaa', hash_bucket_size=10) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_indices_shape( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyFileTest( + test.TestCase, parameterized.TestCase): + + def _write_vocab(self, vocab_strings, file_name): + vocab_file = os.path.join(self.get_temp_dir(), file_name) + with open(vocab_file, 'w') as f: + f.write('\n'.join(vocab_strings)) + return vocab_file + + def setUp(self): + super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp() + + vocab_strings = ['omar', 'stringer', 'marlo'] + self._wire_vocabulary_file_name = self._write_vocab(vocab_strings, + 'wire_vocabulary.txt') + 
self._wire_vocabulary_size = 3 + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_dynamic_zero_length(self): + """Tests _get_sparse_tensors with a dynamic sequence length.""" + inputs = sparse_tensor.SparseTensorValue( + indices=np.zeros((0, 2)), values=[], dense_shape=(2, 0)) + expected = sparse_tensor.SparseTensorValue( + indices=np.zeros((0, 3)), + values=np.array((), dtype=np.int64), + dense_shape=(2, 0, 1)) + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + input_placeholder_shape = list(inputs.dense_shape) + # Make second dimension (sequence length) dynamic. 
+ input_placeholder_shape[1] = None + input_placeholder = array_ops.sparse_placeholder( + dtypes.string, shape=input_placeholder_shape) + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': input_placeholder})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + result = id_weight_pair.id_tensor.eval( + session=sess, feed_dict={input_placeholder: inputs}) + _assert_sparse_tensor_value( + self, expected, result) + + +class SequenceCategoricalColumnWithVocabularyListTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': ('marlo', 'skywalker', 'omar'), + 'dense_shape': (2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)), + 'values': np.array((2, -1, 0), dtype=np.int64), + 'dense_shape': (2, 2, 1)}}, + {'testcase_name': '3D', + 'inputs_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': ('omar', 'skywalker', 'marlo'), + 'dense_shape': (2, 2, 2)}, + 'expected_args': { + 'indices': ((0, 0, 2), (1, 0, 0), (1, 2, 0)), + 'values': np.array((0, -1, 2), dtype=np.int64), + 'dense_shape': (2, 2, 2)}} + ) + def test_get_sparse_tensors(self, inputs_args, expected_args): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + expected = sparse_tensor.SparseTensorValue(**expected_args) + column = sfc.sequence_categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, expected, id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest( + test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { 
+ # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [0, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 0, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[7., 11.], [0., 0.]], + # example 1, ids [[0, 1], [2]] + [[2, 3.5], [7., 11.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [[1], [0, 2]] + [[3., 5.], [4., 6.5]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) 
# id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old._embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected, embedding_lookup.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = 
sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old._embedding_column(categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceSharedEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) 
# id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [0, 2] + # example 2, ids [0] + # example 3, ids [] + indices=((0, 0), (1, 0), (1, 1), (2, 0)), + values=(1, 0, 2, 0), + dense_shape=(4, 2)) + + expected_lookups_a = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + expected_lookups_b = [ + # example 0, ids [1] + [[3., 5.], [0., 0.]], + # example 1, ids [0, 2] + [[1., 2.], [7., 11.]], + # example 2, ids [0] + [[1., 2.], [0., 0.]], + # example 3, ids [] + [[0., 0.], [0., 0.]], + ] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[0] + embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[0] + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with 
monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual( + expected_lookups_a, embedding_lookup_a.eval(session=sess)) + self.assertAllEqual( + expected_lookups_b, embedding_lookup_b.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length_a = [1, 2] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0, 2] + # example 1, ids [1] + indices=((0, 0), (0, 1), (1, 0)), + values=(0, 2, 1), + dense_shape=(2, 2)) + expected_sequence_length_b = [2, 1] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + sequence_length_a = sess.run(sequence_length_a) + self.assertAllEqual(expected_sequence_length_a, sequence_length_a) + self.assertEqual(np.int64, sequence_length_a.dtype) + sequence_length_b = sess.run(sequence_length_b) + self.assertAllEqual(expected_sequence_length_b, sequence_length_b) + self.assertEqual(np.int64, sequence_length_b.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, 
ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length_a = [0, 1, 2, 0, 1, 0] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [0, 1] + indices=((0, 0), (4, 0), (5, 0), (5, 1)), + values=(2, 1, 0, 1), + dense_shape=(6, 2)) + expected_sequence_length_b = [1, 0, 0, 0, 1, 2] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length_a, sequence_length_a.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length_b, sequence_length_b.eval(session=sess)) + + +class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + 'indices': ((0, 0), (1, 0), (1, 1), (3, 0)), + 'values': (2, 0, 1, 1), + 'dense_shape': (4, 2)}, + 'expected': [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + 
[[0., 1., 0.], [0., 0., 0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + # example 2, ids [] + # example 3, ids [[1], [2, 2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0), + (3, 0, 0), (3, 1, 0), (3, 1, 1)), + 'values': (2, 0, 1, 2, 1, 2, 2), + 'dense_shape': (4, 2, 2)}, + 'expected': [ + # example 0, ids [[2]] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [[0, 1], [2]] + [[1., 1., 0.], [0., 0., 1.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [[1], [2, 2]] + [[0., 1., 0.], [0., 0., 2.]]]} + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc_old._indicator_column(categorical_column) + + indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected, indicator_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2, 0, 1), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2]}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2, 0, 1, 2), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2]} + ) + def test_sequence_length(self, inputs_args, expected_sequence_length): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + vocabulary_size = 3 + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = 
fc_old._indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +def _get_sequence_dense_tensor(column, features): + return column.get_sequence_dense_tensor( + fc.FeatureTransformationCache(features), None) + + +class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase): + + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 
'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, values [0., 1] + # example 1, [10.] + 'indices': ((0, 0), (0, 1), (1, 0)), + 'values': (0., 1., 10.), + 'dense_shape': (2, 2)}, + 'expected': [ + [[0.], [1.]], + [[10.], [0.]]]}, + {'testcase_name': '3D', + 'inputs_args': { + # feature 0, ids [[20, 3], [5]] + # feature 1, ids [[3], [8]] + 'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)), + 'values': (20, 3, 5., 3., 8.), + 'dense_shape': (2, 2, 2)}, + 'expected': [ + [[20.], [3.], [5.], [0.]], + [[3.], [0.], [8.], [0.]]]}, + ) + def test_get_sequence_dense_tensor(self, inputs_args, expected): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = _get_sequence_dense_tensor( + numeric_column, {'aaa': inputs}) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + 
dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = _get_sequence_dense_tensor( + numeric_column, {'aaa': sparse_input}) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'sparse_input_args': { + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + 'indices': ((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), + (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 8)}, + 'expected_dense_tensor': [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]}, + {'testcase_name': '3D', + 'sparse_input_args': { + 'indices': ((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6), + (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6), + (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)), + 'values': (0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + 'dense_shape': (2, 2, 8)}, + 'expected_dense_tensor': [ + [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]], + [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]], + [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]], + [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]}, + ) + def test_get_dense_tensor_multi_dim( + self, sparse_input_args, expected_dense_tensor): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = _get_sequence_dense_tensor( + numeric_column, {'aaa': sparse_input}) + + with 
monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + @parameterized.named_parameters( + {'testcase_name': '2D', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '3D', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (1,)}, + {'testcase_name': '2D_with_shape', + 'inputs_args': { + # example 0, ids [2] + # example 1, ids [0, 1] + 'indices': ((0, 0), (1, 0), (1, 1)), + 'values': (2., 0., 1.), + 'dense_shape': (2, 2)}, + 'expected_sequence_length': [1, 1], + 'shape': (2,)}, + {'testcase_name': '3D_with_shape', + 'inputs_args': { + # example 0, ids [[2]] + # example 1, ids [[0, 1], [2]] + 'indices': ((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + 'values': (2., 0., 1., 2.), + 'dense_shape': (2, 2, 2)}, + 'expected_sequence_length': [1, 2], + 'shape': (2,)}, + ) + def test_sequence_length(self, inputs_args, expected_sequence_length, shape): + inputs = sparse_tensor.SparseTensorValue(**inputs_args) + numeric_column = sfc.sequence_numeric_column('aaa', shape=shape) + + _, sequence_length = _get_sequence_dense_tensor( + numeric_column, {'aaa': inputs}) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 
3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = _get_sequence_dense_tensor( + numeric_column, {'aaa': sparse_input}) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index cd747df4d69d2c264f5a64b491da9570b1423770..53efae1e10f30a2c5a42c9997c92ad909d77f58e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -66,6 +66,7 @@ tf_custom_op_py_library( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", "//tensorflow/python:smart_cond", + "//tensorflow/python:sort_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -311,17 +312,3 @@ py_test( "//third_party/py/numpy", ], ) - -py_test( - name = "sort_ops_test", - size = "medium", - srcs = ["python/ops/sort_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 95f5ba90aba6ff8d3f1f5b93bde2211ddf1c231b..e72e50585a3861d4527b66f89e1659d76c85960a 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -15,10 +15,6 @@ """Framework utilities. -See the -[Contrib Framework](https://tensorflow.org/api_guides/python/contrib.framework) -guide. 
- @@assert_same_float_dtype @@assert_scalar @@assert_scalar_int diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index cfdc7df7d8fd4c1406bf447a79038ac33b11e047..00e04b83ac45a83e54eee7a6e4e146fb683c3d98 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -44,17 +44,18 @@ class ExperimentalTest(test.TestCase): # Assert function docs are properly updated. self.assertEqual("_fn", _fn.__name__) - self.assertEqual("fn doc. (experimental)" - "\n" - "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " - "be removed at any time, and without warning." - "\n" - "\nArgs:" - "\n arg0: Arg 0." - "\n arg1: Arg 1." - "\n" - "\nReturns:" - "\n Sum of args.", _fn.__doc__) + self.assertEqual( + "fn doc. (experimental)" + "\n" + "\nWarning: THIS FUNCTION IS EXPERIMENTAL. It may change " + "or be removed at any time, and without warning." + "\n" + "\nArgs:" + "\n arg0: Arg 0." + "\n arg1: Arg 1." + "\n" + "\nReturns:" + "\n Sum of args.", _fn.__doc__) # Assert calling new fn issues log warning. 
self.assertEqual(3, _fn(1, 2)) diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 1921a77c1e96ee3531d1ed0f98e41c27c9d427ac..42184a4e55e292f7921702e3f8909ae54f717702 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -22,173 +22,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +from tensorflow.python.ops import sort_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops as framework_ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops - - -def sort(values, axis=-1, direction='ASCENDING', name=None): - """Sorts a tensor. - - Args: - values: 1-D or higher numeric `Tensor`. - axis: The axis along which to sort. The default is -1, which sorts the last - axis. - direction: The direction in which to sort the values (`'ASCENDING'` or - `'DESCENDING'`). - name: Optional name for the operation. - - Returns: - A `Tensor` with the same dtype and shape as `values`, with the elements - sorted along the given `axis`. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - with framework_ops.name_scope(name, 'sort'): - return _sort_or_argsort(values, axis, direction, return_argsort=False) - - -def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): - """Returns the indices of a tensor that give its sorted order along an axis. - - For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to - `tf.sort(values)`. 
For higher dimensions, the output has the same shape as - `values`, but along the given axis, values represent the index of the sorted - element in that slice of the tensor at the given position. - - Args: - values: 1-D or higher numeric `Tensor`. - axis: The axis along which to sort. The default is -1, which sorts the last - axis. - direction: The direction in which to sort the values (`'ASCENDING'` or - `'DESCENDING'`). - stable: If True, equal elements in the original tensor will not be - re-ordered in the returned order. Unstable sort is not yet implemented, - but will eventually be the default for performance reasons. If you - require a stable order, pass `stable=True` for forwards compatibility. - name: Optional name for the operation. - - Returns: - An int32 `Tensor` with the same shape as `values`. The indices that would - sort each slice of the given `values` along the given `axis`. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - del stable # Unused. - with framework_ops.name_scope(name, 'argsort'): - return _sort_or_argsort(values, axis, direction, return_argsort=True) - - -def _sort_or_argsort(values, axis, direction, return_argsort): - """Internal sort/argsort implementation. - - Args: - values: The input values. - axis: The axis along which to sort. - direction: 'ASCENDING' or 'DESCENDING'. - return_argsort: Whether to return the argsort result. - - Returns: - Either the sorted values, or the indices of the sorted values in the - original tensor. See the `sort` and `argsort` docstrings. - - Raises: - ValueError: If axis is not a constant scalar, or the direction is invalid. - """ - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. 
- axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error - - values = framework_ops.convert_to_tensor(values, name='values') - - return _SORT_IMPL[direction](values, axis_static, return_argsort) - - -def _descending_sort(values, axis, return_argsort=False): - """Sorts values in reverse using `top_k`. - - Args: - values: Tensor of numeric values. - axis: Index of the axis which values should be sorted along. - return_argsort: If False, return the sorted values. If True, return the - indices that would sort the values. - - Returns: - The sorted values. - """ - k = array_ops.shape(values)[axis] - rank = array_ops.rank(values) - static_rank = values.shape.ndims - # Fast path: sorting the last axis. - if axis == -1 or axis + 1 == values.get_shape().ndims: - top_k_input = values - transposition = None - else: - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Calculate the actual axis index if counting from the end. Use the static - # rank if available, or else make the axis back into a tensor. - axis += static_rank or rank - if static_rank is not None: - # Prefer to calculate the transposition array in NumPy and make it a - # constant. - transposition = constant_op.constant( - np.r_[ - # Axes up to axis are unchanged. - np.arange(axis), - # Swap axis and rank - 1. - [static_rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - np.arange(axis + 1, static_rank - 1), - # Swap axis and rank - 1. - [axis]], - name='transposition') - else: - # Generate the transposition array from the tensors. - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. 
- math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - - values, indices = nn_ops.top_k(top_k_input, k) - return_value = indices if return_argsort else values - if transposition is not None: - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return_value = array_ops.transpose(return_value, transposition) - return return_value - - -def _ascending_sort(values, axis, return_argsort=False): - # Negate the values to get the ascending order from descending sort. - values_or_indices = _descending_sort(-values, axis, return_argsort) - # If not argsort, negate the values again. - return values_or_indices if return_argsort else -values_or_indices - - -_SORT_IMPL = { - 'ASCENDING': _ascending_sort, - 'DESCENDING': _descending_sort, -} +sort = sort_ops.sort +argsort = sort_ops.argsort diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 7ee39f304ab213a8fa4e7a6f03cda88037bff9a1..cf5b9d9476738e58f6f1286bf5652d55b49ed4d5 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -114,7 +114,7 @@ def gan_model( discriminator_gen_outputs = discriminator_fn(generated_data, generator_inputs) with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) + real_data = _convert_tensor_or_l_or_d(real_data) discriminator_real_outputs = discriminator_fn(real_data, generator_inputs) if check_shapes: @@ -1071,8 +1071,19 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): return get_hooks +def _num_joint_steps(train_steps): + g_steps = train_steps.generator_train_steps + d_steps = train_steps.discriminator_train_steps + # Get the number of each type of step that should be run. 
+ num_d_and_g_steps = min(g_steps, d_steps) + num_g_steps = g_steps - num_d_and_g_steps + num_d_steps = d_steps - num_d_and_g_steps + + return num_d_and_g_steps, num_g_steps, num_d_steps + + def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): - """Returns a hooks function for sequential GAN training. + """Returns a hooks function for joint GAN training. When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON ALL OPTIMIZERS TO AVOID RACE CONDITIONS. @@ -1105,12 +1116,7 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. """ - g_steps = train_steps.generator_train_steps - d_steps = train_steps.discriminator_train_steps - # Get the number of each type of step that should be run. - num_d_and_g_steps = min(g_steps, d_steps) - num_g_steps = g_steps - num_d_and_g_steps - num_d_steps = d_steps - num_d_and_g_steps + num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps) def get_hooks(train_ops): g_op = train_ops.generator_train_op diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 64d670619905a427a84bee4b661228abca591fae..31d9e827005219bdc07df86d42bef40a38f314f1 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -519,7 +519,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): """Test output type.""" loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.GANLoss) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) @parameterized.named_parameters( ('cyclegan', create_cyclegan_model), @@ -528,7 +528,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): def test_cyclegan_output_type(self, get_gan_model_fn): loss = 
train.cyclegan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.CycleGANLoss) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) @parameterized.named_parameters( ('gan', create_gan_model, False), @@ -923,8 +923,7 @@ class GANTrainOpsTest(test.TestCase, parameterized.TestCase): model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt) self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # No new trainable variables should have been added. - self.assertEqual(num_trainable_vars, - len(variables_lib.get_trainable_variables())) + self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars) g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1) d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1) diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 94f522c04e5a09ed2d9355fa675125c340407923..fbccbead03fc0d641db40ede661bf3677d44c45d 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -170,6 +170,14 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { // Record "call" in active_ so that it can be aborted cleanly. RegisterCall(call); + // RendezvousMgr already aborted, shouldn't send RPC call any more + if (!call->status().ok()) { + done(call->status(), Args(), Args(), Tensor(), false); + session()->worker_cache->ReleaseWorker(src_worker, rwi); + delete call; + return; + } + // Start "call". 
Ref(); call->Start([this, call, src_worker, rwi, done]() { diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index d8584e4e6b7470472b0e1911b9e34c8c80a42e0f..b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -52,9 +52,10 @@ Status GdrServer::Init() { [this](const WorkerEnv* env) { return new GdrRendezvousMgr(env, remote_memory_manager_.get()); }; - WorkerCreationFunction worker_func = [this](WorkerEnv* env) { + WorkerCreationFunction worker_func = [this](WorkerEnv* env, + const ConfigProto& config) { return std::unique_ptr( - new GdrWorker(env, remote_memory_manager_.get())); + new GdrWorker(env, config, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index ce1d8d2d73000559f03046aceacb169890ecc1b6..867cb83f42034c8e9061e333ea671457745f92c3 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -39,9 +39,9 @@ limitations under the License. 
namespace tensorflow { -GdrWorker::GdrWorker(WorkerEnv* worker_env, +GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) - : GrpcWorker(worker_env), + : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), recv_tensor_recent_request_ids_(100000) {} diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 65105ed997300aa77202301cdd8dddacb0309880..39f11e6bde5a1ca7ae91ead02279d22d70af027b 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -25,7 +25,8 @@ namespace tensorflow { class GdrWorker : public GrpcWorker { public: - GdrWorker(WorkerEnv* env, RemoteMemoryManager* remote_memory_manager); + GdrWorker(WorkerEnv* env, const ConfigProto& config, + RemoteMemoryManager* remote_memory_manager); // Serve the RecvTensorRequest but omit the tensor content and transmit it // out-of-band using GPU Direct RDMA whenever possible. diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD index 9393b702d11a2ef84586f712d30c26fe2a8972bb..2698b83a56a1121fa30f5b05ffa027b4dfd4ba95 100644 --- a/tensorflow/contrib/ignite/BUILD +++ b/tensorflow/contrib/ignite/BUILD @@ -22,48 +22,92 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", + ":igfs_ops", ], ) tf_custom_op_library( - name = "_dataset_ops.so", - srcs = ["ops/dataset_ops.cc"], - deps = [":dataset_kernels"], + name = "_ignite_ops.so", + srcs = [ + "kernels/igfs/igfs.h", + "ops/dataset_ops.cc", + "ops/igfs_ops.cc", + ], + deps = [ + ":dataset_kernels", + ":igfs_kernels", + ], ) tf_gen_op_libs( op_lib_names = ["dataset_ops"], ) +tf_gen_op_libs( + op_lib_names = ["igfs_ops"], + deps = [":igfs_kernels"], +) + cc_library( - name = "dataset_kernels", + name = "ignite_client", srcs = [ - "kernels/ignite_dataset_ops.cc", - "kernels/ignite_client.h", - "kernels/ignite_byte_swapper.h", - "kernels/ignite_plain_client.h", - 
"kernels/ignite_ssl_wrapper.h", - "kernels/ignite_ssl_wrapper.cc", - "kernels/ignite_binary_object_parser.h", - "kernels/ignite_binary_object_parser.cc", - "kernels/ignite_dataset.h", - "kernels/ignite_dataset.cc", - "kernels/ignite_dataset_iterator.h", - "kernels/ignite_dataset_iterator.cc", + "kernels/client/ignite_client.h", + "kernels/client/ignite_byte_swapper.h", + "kernels/client/ignite_plain_client.h", + "kernels/client/ignite_ssl_wrapper.h", + "kernels/client/ignite_ssl_wrapper.cc", ] + if_not_windows([ - "kernels/ignite_plain_client_unix.cc", + "kernels/client/ignite_plain_client_unix.cc", ]) + if_windows([ - "kernels/ignite_plain_client_windows.cc", + "kernels/client/ignite_plain_client_windows.cc", ]), copts = if_windows([ "-DWIN32_LEAN_AND_MEAN", ]), deps = [ "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", "@boringssl//:ssl", "@protobuf_archive//:protobuf_headers", ], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/dataset/ignite_binary_object_parser.cc", + "kernels/dataset/ignite_binary_object_parser.h", + "kernels/dataset/ignite_dataset.cc", + "kernels/dataset/ignite_dataset.h", + "kernels/dataset/ignite_dataset_iterator.cc", + "kernels/dataset/ignite_dataset_iterator.h", + "kernels/dataset/ignite_dataset_ops.cc", + ], + deps = [ + ":ignite_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "igfs_kernels", + srcs = [ + "kernels/igfs/igfs.cc", + "kernels/igfs/igfs.h", + "kernels/igfs/igfs_client.cc", + "kernels/igfs/igfs_client.h", + "kernels/igfs/igfs_extended_tcp_client.cc", + "kernels/igfs/igfs_extended_tcp_client.h", + "kernels/igfs/igfs_messages.cc", + "kernels/igfs/igfs_messages.h", + "kernels/igfs/igfs_random_access_file.cc", + "kernels/igfs/igfs_random_access_file.h", + "kernels/igfs/igfs_writable_file.cc", + "kernels/igfs/igfs_writable_file.h", + ], + deps = [":ignite_client"], 
alwayslink = 1, ) @@ -82,10 +126,29 @@ py_library( ], ) +py_library( + name = "igfs_ops", + srcs = [ + "python/ops/igfs_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":igfs_op_loader", + "//tensorflow/python:util", + "//tensorflow/python/data/util:nest", + ], +) + tf_gen_op_wrapper_py( name = "gen_dataset_ops", out = "python/ops/gen_dataset_ops.py", - deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"], + deps = [":dataset_ops_op_lib"], +) + +tf_gen_op_wrapper_py( + name = "gen_igfs_ops", + out = "python/ops/gen_igfs_ops.py", + deps = [":igfs_ops_op_lib"], ) tf_kernel_library( @@ -97,13 +160,22 @@ tf_kernel_library( alwayslink = 1, ) +tf_kernel_library( + name = "igfs_ops_kernels", + deps = [ + ":igfs_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "ignite_op_loader", srcs = ["python/ops/ignite_op_loader.py"], - dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"], + dso = [":_ignite_ops.so"], kernels = [ ":dataset_ops_kernels", - "//tensorflow/contrib/ignite:dataset_ops_op_lib", + ":dataset_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ @@ -113,6 +185,22 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "igfs_op_loader", + srcs = ["python/ops/igfs_op_loader.py"], + dso = [":_ignite_ops.so"], + kernels = [ + ":igfs_ops_kernels", + ":igfs_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_igfs_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + # The Apache Ignite servers have to setup before the test and tear down # after the test manually. The docker engine has to be installed. 
# @@ -122,8 +210,11 @@ tf_custom_op_py_library( # To tear down Apache Ignite servers: # $ bash ./python/tests/stop_ignite.sh tf_py_test( - name = "ignite_dataset_test", - srcs = ["python/tests/ignite_dataset_test.py"], + name = "ignite_test", + srcs = [ + "python/tests/igfs_test.py", + "python/tests/ignite_dataset_test.py", + ], additional_deps = [ ":ignite", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index 55c89d27996318dabb29bb15372411005301ebd9..c7db0b77e25668fb8a42d204776044420f403e44 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -1,19 +1,32 @@ -# Ignite Dataset - -- [Overview](#overview) -- [Features](#features) - * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) - * [Structured Objects](#structured-objects) - * [Distributed Training](#distributed-training) - * [SSL Connection](#ssl-connection) - * [Windows Support](#windows-support) -- [Try it out](#try-it-out) -- [Limitations](#limitations) +# Apache Ignite Integration + +- [Overview](#overview) +- [Features](#features) + * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) + * [Structured Objects](#structured-objects) + * [Distributed Training](#distributed-training) + * [Distributed File System](#distributed-file-system) + * [SSL Connection](#ssl-connection) + * [Windows Support](#windows-support) +- [Try it out](#try-it-out) + * [Ignite Dataset](#ignite-dataset) + * [IGFS](#igfs) +- [Limitations](#limitations) ## Overview -[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for -transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. 
The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow. +[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed +database, caching, and processing platform for transactional, analytical, and +streaming workloads, delivering in-memory speeds at petabyte scale. This contrib +package contains an integration between Apache Ignite and TensorFlow. The +integration is based on +[tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow +side and +[Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) +from Apache Ignite side. It allows to use Apache Ignite as a data source for +neural network training, inference and all other computations supported by +TensorFlow. Another part of this module is an integration with distributed file +system based on Apache Ignite. ## Features @@ -134,6 +147,23 @@ Ignite Dataset allows using these two aspects of distributed neural network trai High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. +### Distributed File System + +In addition to database functionality Apache Ignite provides a distributed file +system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS +delivers a similar functionality to Hadoop HDFS, but only in-memory. In fact, in +addition to its own APIs, IGFS implements Hadoop FileSystem API and can be +transparently plugged into Hadoop or Spark deployments. This contrib package +contains an integration between IGFS and TensorFlow. 
The integration is based +on [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) +from TensorFlow side and +[IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache +Ignite side. It has numerous uses, for example: * Checkpoints of state can be +saved to IGFS for reliability and fault-tolerance. * Training processes +communicate with TensorBoard by writing event files to a directory, which +TensorBoard watches. IGFS allows this communication to work even when +TensorBoard runs in a different process or machine. + ### SSL Connection Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. @@ -141,9 +171,12 @@ Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikip ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> ->>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite") ->>> ... +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES", + certfile="client.pem", + cert_password="password", + username="ignite", + password="ignite") ``` ### Windows Support @@ -152,7 +185,16 @@ Ignite Dataset is fully compatible with Windows. You can use it as part of Tenso ## Try it out -The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). 
You need to start this container on your machine: +Following examples will help you to easily start working with this module. + +### Ignite Dataset + +The simplest way to try Ignite Dataset is to run a +[Docker](https://www.docker.com/) container with Apache Ignite and loaded +[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interact with +it using Ignite Dataset. Such container is available on Docker Hub: +[dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). +You need to start this container on your machine: ``` docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist @@ -162,6 +204,35 @@ After that you will be able to work with it following way: ![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +### IGFS + +The simplest way to try IGFS with TensorFlow is to run +[Docker](https://www.docker.com/) container with Apache Ignite and enabled IGFS +and then interact with it using TensorFlow +[tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such container +is available on Docker Hub: +[dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/). +You need to start this container on your machine: + +``` +docker run -it -p 10500:10500 dmitrievanthony/ignite-with-igfs +``` + +After that you will be able to work with it following way: + +```python +>>> import tensorflow as tf +>>> import tensorflow.contrib.ignite.python.ops.igfs_ops +>>> +>>> with tf.gfile.Open("igfs:///hello.txt", mode='w') as w: +>>> w.write("Hello, world!") +>>> +>>> with tf.gfile.Open("igfs:///hello.txt", mode='r') as r: +>>> print(r.read()) + +Hello, world! +``` + ## Limitations Presently, Ignite Dataset works with assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object.
Another limitation concerns structured objects, Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure. diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h similarity index 67% rename from tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h rename to tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h index 46df3e39dc4ec6dd4ef5730a184264eaa9fc5872..aac950fcc2aaf016959bbda876ac93df4baea417 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ #include #include "tensorflow/core/platform/byte_order.h" @@ -25,76 +25,75 @@ class ByteSwapper { public: ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; } - inline void SwapIfRequiredInt16(int16_t *x) const { + void SwapIfRequiredInt16(int16_t *x) const { if (swap_) { Swap16(x); } } - inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const { + void SwapIfRequiredUnsignedInt16(uint16_t *x) const { if (swap_) { Swap16(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt32(int32_t *x) const { + void SwapIfRequiredInt32(int32_t *x) const { if (swap_) { Swap32(x); } } - inline void SwapIfRequiredFloat(float *x) const { + void SwapIfRequiredFloat(float *x) const { if (swap_) { Swap32(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt64(int64_t *x) const { + void SwapIfRequiredInt64(int64_t *x) const 
{ if (swap_) { Swap64(x); } } - inline void SwapIfRequiredDouble(double *x) const { + void SwapIfRequiredDouble(double *x) const { if (swap_) { Swap64(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { + void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap16(&x[i]); } } - inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, - int32_t length) const { + void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap16(reinterpret_cast(&x[i])); } } - inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { + void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap32(&x[i]); } } - inline void SwapIfRequiredFloatArr(float *x, int32_t length) const { + void SwapIfRequiredFloatArr(float *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap32(reinterpret_cast(&x[i])); } } - inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { + void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap64(&x[i]); } } - inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const { + void SwapIfRequiredDoubleArr(double *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap64(reinterpret_cast(&x[i])); @@ -102,16 +101,16 @@ class ByteSwapper { } private: - inline void Swap16(int16_t *x) const { + void Swap16(int16_t *x) const { *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF); } - inline void Swap32(int32_t *x) const { + void Swap32(int32_t *x) const { *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) | (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF); } - inline void Swap64(int64_t *x) const { + void Swap64(int64_t *x) const { *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) | (((*x >> 16) & 
0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) | (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) | @@ -123,4 +122,4 @@ class ByteSwapper { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/client/ignite_client.h similarity index 74% rename from tensorflow/contrib/ignite/kernels/ignite_client.h rename to tensorflow/contrib/ignite/kernels/client/ignite_client.h index 459b50b48fd95ad105bccaca4076160e0ef152ee..0da80769260d065c4ac6601c0e5cd7050b6b61cb 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_client.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_client.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -32,44 +32,44 @@ class Client { virtual Status ReadData(uint8_t *buf, const int32_t length) = 0; virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0; - inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); } + Status ReadByte(uint8_t *data) { return ReadData(data, 1); } - inline Status ReadShort(int16_t *data) { + Status ReadShort(int16_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2)); byte_swapper_.SwapIfRequiredInt16(data); return Status::OK(); } - inline 
Status ReadInt(int32_t *data) { + Status ReadInt(int32_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4)); byte_swapper_.SwapIfRequiredInt32(data); return Status::OK(); } - inline Status ReadLong(int64_t *data) { + Status ReadLong(int64_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8)); byte_swapper_.SwapIfRequiredInt64(data); return Status::OK(); } - inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } + Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } - inline Status WriteShort(const int16_t data) { + Status WriteShort(const int16_t data) { int16_t tmp = data; byte_swapper_.SwapIfRequiredInt16(&tmp); return WriteData((uint8_t *)&tmp, 2); } - inline Status WriteInt(const int32_t data) { + Status WriteInt(const int32_t data) { int32_t tmp = data; byte_swapper_.SwapIfRequiredInt32(&tmp); return WriteData((uint8_t *)&tmp, 4); } - inline Status WriteLong(const int64_t data) { + Status WriteLong(const int64_t data) { int64_t tmp = data; byte_swapper_.SwapIfRequiredInt64(&tmp); return WriteData((uint8_t *)&tmp, 8); @@ -81,4 +81,4 @@ class Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h similarity index 80% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client.h rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h index 75424c19ee4b7df5378aa23cb41db1752e8d0651..546583246042855d179ebbb18b7dca711063b3f4 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" namespace tensorflow { @@ -40,4 +40,4 @@ class PlainClient : public Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc index cf672942c61e1239332711db12e62088737c4f41..54efb5b61761708a28dd031b8321ffba9a53ffa9 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" #include #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc index dad5aace5fabe1df58bb9579bf578f4c35324315..a99a3ada558e51c13ed47eb72911eb5862e71a60 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" #define WIN32_LEAN_AND_MEAN #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc index ceb479b0846574a35d86002ebb9c3e8e1d3687ac..8f09c24a3bedda524264f30282a0ad019d515540 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h" #include #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h similarity index 82% rename from tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h rename to tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h index 0406644bbaab3de816540ce85e84b489ea9fff12..543e03d1efc3ff186c9db399af18f7aa8ad2c450 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" #include @@ -48,4 +48,4 @@ class SslWrapper : public Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc similarity index 99% rename from tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc index 
2c8a7d44b07b43f788bcbc0850b5162cc14dd951..4218ec05f2c3486dd91e2188b674e01d6aadaa2b 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h similarity index 87% rename from tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h index eb1f856643a790de6acaa82d4b8ad894fd364376..3e8a1a19623fab3e027db16228e0228e8ec4989a 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ #include -#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -78,4 +78,4 @@ enum ObjectType { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_dataset.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc index c4a7d3c513a796c9d95b371bedc609fd75188817..ace96e7b09fcf314757367baed66f622b294e43c 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h similarity index 91% rename from tensorflow/contrib/ignite/kernels/ignite_dataset.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h index 66bfdf2e2a168e59cd2fec8e2ac5b8fd482d5c15..db3bafb11f2a0047c22ece6d2bc1722afaa5ffdf 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ #include "tensorflow/core/framework/dataset.h" @@ -60,4 +60,4 @@ class IgniteDataset : public DatasetBase { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc index 5da9127aa6a3a4bc16347e6890cc1ba44406c0d5..ce8972f1e7fd59235556cb9514011f0b836077de 100644 --- 
a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h" -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" -#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h similarity index 87% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h index c499e2c9ccfac5c15db08c8fd8b26c37aa0404f3..5868c2cb67f9d5c91654db8cf4bb4bbc072fc1ac 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" -#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { @@ -96,4 +96,4 @@ constexpr int32_t kMinResLength = 12; } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc index f75b1c5ff55ca9ee493148ff79c2edd4b15ac42a..f2108775e29b53765138dcd971bec89d7a10ce40 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" -#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/strings/numbers.h" diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae2dbcc2cf5d0ae7e09a26a199dc0c3c80fe22c1 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs.cc @@ -0,0 +1,331 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/file_system_helper.h" + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h" + +namespace tensorflow { + +static string GetEnvOrElse(const string &env, string default_value) { + const char *env_c_str = env.c_str(); + return getenv(env_c_str) != nullptr ? getenv(env_c_str) : default_value; +} + +static string MakeRelative(const string &a, const string &b) { + string max = a; + string min = b; + bool first = b.size() > a.size(); + + if (first) { + max = b; + min = a; + } + + auto r = mismatch(min.begin(), min.end(), max.begin()); + return string((first ? r.first : r.second), first ? 
min.end() : max.end()); +} + +string IGFS::TranslateName(const string &name) const { + StringPiece scheme, namenode, path; + io::ParseURI(name, &scheme, &namenode, &path); + return string(path.data(), path.length()); +} + +IGFS::IGFS() + : host_(GetEnvOrElse("IGFS_HOST", "localhost")), + port_([] { + int port; + if (strings::safe_strto32(GetEnvOrElse("IGFS_PORT", "10500").c_str(), + &port)) { + return port; + } else { + LOG(WARNING) + << "IGFS_PORT environment variable had an invalid value: " + << getenv("IGFS_PORT") << "\nUsing default port 10500."; + return 10500; + } + }()), + fs_name_(GetEnvOrElse("IGFS_FS_NAME", "default_fs")) { + LOG(INFO) << "IGFS created [host=" << host_ << ", port=" << port_ + << ", fs_name=" << fs_name_ << "]"; +} + +IGFS::~IGFS() { + LOG(INFO) << "IGFS destroyed [host=" << host_ << ", port=" << port_ + << ", fs_name=" << fs_name_ << "]"; +} + +Status IGFS::NewRandomAccessFile(const string &file_name, + std::unique_ptr *result) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse open_read_response(true); + TF_RETURN_IF_ERROR(client->OpenRead(&open_read_response, path)); + + int64 resource_id = open_read_response.res.stream_id; + result->reset(new IGFSRandomAccessFile(path, resource_id, std::move(client))); + + LOG(INFO) << "New random access file completed successfully [file_name=" + << file_name << "]"; + + return Status::OK(); +} + +Status IGFS::NewWritableFile(const string &file_name, + std::unique_ptr *result) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse exists_response(false); + TF_RETURN_IF_ERROR(client->Exists(&exists_response, path)); + + if (exists_response.res.exists) { + CtrlResponse del_response(false); + 
TF_RETURN_IF_ERROR(client->Delete(&del_response, path, false)); + } + + CtrlResponse open_create_resp(false); + TF_RETURN_IF_ERROR(client->OpenCreate(&open_create_resp, path)); + + int64 resource_id = open_create_resp.res.stream_id; + result->reset(new IGFSWritableFile(path, resource_id, std::move(client))); + + LOG(INFO) << "New writable file completed successfully [file_name=" + << file_name << "]"; + + return Status::OK(); +} + +Status IGFS::NewAppendableFile(const string &file_name, + std::unique_ptr *result) { + std::unique_ptr client = CreateClient(); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse exists_response(false); + TF_RETURN_IF_ERROR(client->Exists(&exists_response, file_name)); + + if (exists_response.res.exists) { + CtrlResponse del_response(false); + TF_RETURN_IF_ERROR(client->Delete(&del_response, file_name, false)); + } + + CtrlResponse open_append_resp(false); + TF_RETURN_IF_ERROR(client->OpenAppend(&open_append_resp, file_name)); + + result->reset(new IGFSWritableFile(TranslateName(file_name), + open_append_resp.res.stream_id, + std::move(client))); + + LOG(INFO) << "New appendable file completed successfully [file_name=" + << file_name << "]"; + + return Status::OK(); +} + +Status IGFS::NewReadOnlyMemoryRegionFromFile( + const string &file_name, std::unique_ptr *result) { + return errors::Unimplemented("IGFS does not support ReadOnlyMemoryRegion"); +} + +Status IGFS::FileExists(const string &file_name) { + std::unique_ptr client = CreateClient(); + const string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse exists_response(false); + TF_RETURN_IF_ERROR(client->Exists(&exists_response, path)); + + if (!exists_response.res.exists) + return errors::NotFound("File ", path, " not found"); + + LOG(INFO) << "File exists completed successfully [file_name=" << file_name + 
<< "]"; + + return Status::OK(); +} + +Status IGFS::GetChildren(const string &file_name, std::vector *result) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + path = path + "/"; + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse list_paths_response(false); + TF_RETURN_IF_ERROR(client->ListPaths(&list_paths_response, path)); + + *result = std::vector(); + std::vector entries = list_paths_response.res.entries; + + for (IGFSPath &value : entries) + result->push_back(MakeRelative(value.path, path)); + + LOG(INFO) << "Get children completed successfully [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::GetMatchingPaths(const string &pattern, + std::vector *results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + +Status IGFS::DeleteFile(const string &file_name) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse del_response(false); + TF_RETURN_IF_ERROR(client->Delete(&del_response, path, false)); + + if (!del_response.res.exists) + return errors::NotFound("File ", path, " not found"); + + LOG(INFO) << "Delete file completed successfully [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::CreateDir(const string &file_name) { + std::unique_ptr client = CreateClient(); + const string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse mkdir_response(false); + TF_RETURN_IF_ERROR(client->MkDir(&mkdir_response, path)); + + if (!mkdir_response.res.successful) + return errors::Unknown("Can't create directory ", path); + + LOG(INFO) << "Create dir completed successful [file_name=" << file_name + << "]"; + + return 
Status::OK(); +} + +Status IGFS::DeleteDir(const string &file_name) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse list_files_response(false); + TF_RETURN_IF_ERROR(client->ListFiles(&list_files_response, path)); + + if (!list_files_response.res.entries.empty()) { + return errors::FailedPrecondition("Can't delete a non-empty directory"); + } else { + CtrlResponse del_response(false); + TF_RETURN_IF_ERROR(client->Delete(&del_response, path, true)); + } + + LOG(INFO) << "Delete dir completed successful [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::GetFileSize(const string &file_name, uint64 *size) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse info_response(false); + TF_RETURN_IF_ERROR(client->Info(&info_response, path)); + + *size = info_response.res.file_info.length; + + LOG(INFO) << "Get file size completed successful [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::RenameFile(const string &src, const string &dst) { + std::unique_ptr client = CreateClient(); + string src_path = TranslateName(src); + string dst_path = TranslateName(dst); + + if (FileExists(dst).ok()) DeleteFile(dst); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse rename_response(false); + TF_RETURN_IF_ERROR(client->Rename(&rename_response, src_path, dst_path)); + + if (!rename_response.res.successful) + return errors::NotFound("File ", src_path, " not found"); + + LOG(INFO) << "Rename file completed successful [src=" << src + << ", dst=" << dst << "]"; + + return Status::OK(); +} + +Status IGFS::Stat(const string &file_name, FileStatistics 
*stats) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse info_response(false); + TF_RETURN_IF_ERROR(client->Info(&info_response, path)); + + IGFSFile info = info_response.res.file_info; + + *stats = FileStatistics(info.length, info.modification_time * 1000000, + (info.flags & 0x1) != 0); + + LOG(INFO) << "Stat completed successful [file_name=" << file_name << "]"; + + return Status::OK(); +} + +std::unique_ptr IGFS::CreateClient() const { + return std::unique_ptr( + new IGFSClient(host_, port_, fs_name_, "")); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs.h b/tensorflow/contrib/ignite/kernels/igfs/igfs.h new file mode 100644 index 0000000000000000000000000000000000000000..4c347e937f75e8eea108811e6a3189412e22a982 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFS : public FileSystem { + public: + IGFS(); + ~IGFS(); + Status NewRandomAccessFile( + const string& file_name, + std::unique_ptr* result) override; + Status NewWritableFile(const string& fname, + std::unique_ptr* result) override; + Status NewAppendableFile(const string& fname, + std::unique_ptr* result) override; + Status NewReadOnlyMemoryRegionFromFile( + const string& fname, + std::unique_ptr* result) override; + Status FileExists(const string& fname) override; + Status GetChildren(const string& dir, std::vector* result) override; + Status GetMatchingPaths(const string& pattern, + std::vector* results) override; + Status DeleteFile(const string& fname) override; + Status CreateDir(const string& name) override; + Status DeleteDir(const string& name) override; + Status GetFileSize(const string& fname, uint64* size) override; + Status RenameFile(const string& src, const string& target) override; + Status Stat(const string& fname, FileStatistics* stat) override; + string TranslateName(const string& name) const override; + + private: + std::unique_ptr CreateClient() const; + + const string host_; + const int port_; + const string fs_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f97c34fdd8b026a04506fd0ef9f3cc74129a9da --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" + +namespace tensorflow { + +IGFSClient::IGFSClient(const string &host, int port, const string &fs_name, + const string &user_name) + : fs_name_(fs_name), + user_name_(user_name), + client_(ExtendedTCPClient(host, port, true)) { + client_.Connect(); +} + +IGFSClient::~IGFSClient() { client_.Disconnect(); } + +Status IGFSClient::SendRequestGetResponse(const Request &request, + Response *response) { + TF_RETURN_IF_ERROR(request.Write(&client_)); + client_.reset(); + + if (response != nullptr) { + TF_RETURN_IF_ERROR(response->Read(&client_)); + client_.reset(); + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h new file mode 100644 index 0000000000000000000000000000000000000000..bbec7b000779be8772e850a556affffa1b3b6803 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +class IGFSClient { + public: + IGFSClient(const string &host, int port, const string &fs_name, + const string &user_name); + ~IGFSClient(); + + Status Handshake(CtrlResponse *res) { + return SendRequestGetResponse(HandshakeRequest(fs_name_, {}), res); + } + + Status ListFiles(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(ListFilesRequest(user_name_, path), res); + } + + Status ListPaths(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(ListPathsRequest(user_name_, path), res); + } + + Status Info(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(InfoRequest(user_name_, path), res); + } + + Status OpenCreate(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(OpenCreateRequest(user_name_, path), res); + } + + Status OpenAppend(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(OpenAppendRequest(user_name_, path), res); + } + + Status OpenRead(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(OpenReadRequest(user_name_, path), res); + } + + Status Exists(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(ExistsRequest(user_name_, path), res); + } + + Status MkDir(CtrlResponse *res, const string &path) { + return 
SendRequestGetResponse(MakeDirectoriesRequest(user_name_, path), + res); + } + + Status Delete(CtrlResponse *res, const string &path, + bool recursive) { + return SendRequestGetResponse(DeleteRequest(user_name_, path, recursive), + res); + } + + Status WriteBlock(int64_t stream_id, const uint8_t *data, int32_t len) { + return SendRequestGetResponse(WriteBlockRequest(stream_id, data, len), + nullptr); + } + + Status ReadBlock(ReadBlockCtrlResponse *res, int64_t stream_id, int64_t pos, + int32_t length) { + return SendRequestGetResponse(ReadBlockRequest(stream_id, pos, length), + res); + } + + Status Close(CtrlResponse *res, int64_t stream_id) { + return SendRequestGetResponse(CloseRequest(stream_id), res); + } + + Status Rename(CtrlResponse *res, const string &source, + const string &dest) { + return SendRequestGetResponse(RenameRequest(user_name_, source, dest), res); + } + + private: + Status SendRequestGetResponse(const Request &request, Response *response); + + const string fs_name_; + const string user_name_; + ExtendedTCPClient client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea63436546d8b244b921206f9577c91b6578a775 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h" + +namespace tensorflow { + +ExtendedTCPClient::ExtendedTCPClient(const string &host, int port, + bool big_endian) + : PlainClient(host, port, big_endian), pos_(0) {} + +Status ExtendedTCPClient::ReadData(uint8_t *buf, const int32_t length) { + TF_RETURN_IF_ERROR(PlainClient::ReadData(buf, length)); + pos_ += length; + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteData(const uint8_t *buf, const int32_t length) { + TF_RETURN_IF_ERROR(PlainClient::WriteData(buf, length)); + pos_ += length; + + return Status::OK(); +} + +Status ExtendedTCPClient::Ignore(int n) { + uint8_t buf[n]; + return ReadData(buf, n); +} + +Status ExtendedTCPClient::SkipToPos(int target_pos) { + return Ignore(std::max(0, target_pos - pos_)); +} + +Status ExtendedTCPClient::ReadBool(bool *res) { + uint8_t buf = 0; + TF_RETURN_IF_ERROR(ReadData(&buf, 1)); + *res = buf != 0; + + return Status::OK(); +} + +Status ExtendedTCPClient::ReadNullableString(string *res) { + bool is_empty = false; + TF_RETURN_IF_ERROR(ReadBool(&is_empty)); + + if (!is_empty) { + TF_RETURN_IF_ERROR(ReadString(res)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::ReadString(string *res) { + int16_t length; + TF_RETURN_IF_ERROR(ReadShort(&length)); + + uint8_t *buf = new uint8_t[length]; + Status status = ReadData(buf, length); + + if (status.ok()) res->assign(reinterpret_cast(buf), length); + + delete[] buf; + return status; +} + +Status 
ExtendedTCPClient::ReadStringMap(std::map *res) { + int size; + TF_RETURN_IF_ERROR(ReadInt(&size)); + + for (int i = 0; i < size; i++) { + string key; + string val; + TF_RETURN_IF_ERROR(ReadString(&key)); + TF_RETURN_IF_ERROR(ReadString(&val)); + + res->insert(std::pair(std::move(key), std::move(val))); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteSize(std::map::size_type s) { + return WriteInt(s); +} + +Status ExtendedTCPClient::FillWithZerosUntil(int n) { + int to_skip = std::max(0, n - pos_); + + for (int i = 0; i < to_skip; i++) { + TF_RETURN_IF_ERROR(WriteByte(0)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteBool(bool val) { + return WriteByte((char)(val ? 1 : 0)); +} + +Status ExtendedTCPClient::WriteString(string str) { + if (!str.empty()) { + TF_RETURN_IF_ERROR(WriteBool(false)); + size_t l = str.length(); + if (l > std::numeric_limits::max()) + return errors::InvalidArgument("String is too long"); + + TF_RETURN_IF_ERROR(WriteShort(l)); + TF_RETURN_IF_ERROR(WriteData(reinterpret_cast(str.c_str()), + str.length())); + } else { + TF_RETURN_IF_ERROR(WriteBool(true)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteStringMap(std::map map) { + std::map::size_type size = map.size(); + TF_RETURN_IF_ERROR(WriteSize(size)); + + for (auto &x : map) { + TF_RETURN_IF_ERROR(WriteString(x.first)); + TF_RETURN_IF_ERROR(WriteString(x.second)); + } + + return Status::OK(); +} + +void ExtendedTCPClient::reset() { pos_ = 0; } + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h new file mode 100644 index 0000000000000000000000000000000000000000..c5de342fd0c20cf5b01b647756797631b8a3f203 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" + +namespace tensorflow { + +class ExtendedTCPClient : public PlainClient { + public: + ExtendedTCPClient(const string &host, int port, bool big_endian); + Status ReadData(uint8_t *buf, const int32_t length) override; + Status WriteData(const uint8_t *buf, const int32_t length) override; + Status Ignore(int n); + Status SkipToPos(int target_pos); + Status ReadBool(bool *res); + Status ReadNullableString(string *res); + Status ReadString(string *res); + Status ReadStringMap(std::map *res); + Status WriteSize(std::map::size_type s); + Status FillWithZerosUntil(int n); + Status WriteBool(bool val); + Status WriteString(string str); + Status WriteStringMap(std::map map); + void reset(); + + private: + int pos_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c63f40f35fa53bc51c44f574df50ad0c79fba91 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc @@ -0,0 
+1,344 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +Status IGFSPath::Read(ExtendedTCPClient *client) { + return client->ReadNullableString(&path); +} + +Status IGFSFile::Read(ExtendedTCPClient *client) { + int32_t block_size; + int64_t group_block_size; + std::map properties = {}; + int64_t access_time; + + bool has_path; + TF_RETURN_IF_ERROR(client->ReadBool(&has_path)); + if (has_path) { + IGFSPath path = {}; + TF_RETURN_IF_ERROR(path.Read(client)); + } + + TF_RETURN_IF_ERROR(client->ReadInt(&block_size)); + TF_RETURN_IF_ERROR(client->ReadLong(&group_block_size)); + TF_RETURN_IF_ERROR(client->ReadLong(&length)); + TF_RETURN_IF_ERROR(client->ReadStringMap(&properties)); + TF_RETURN_IF_ERROR(client->ReadLong(&access_time)); + TF_RETURN_IF_ERROR(client->ReadLong(&modification_time)); + TF_RETURN_IF_ERROR(client->ReadByte(&flags)); + + return Status::OK(); +} + +Request::Request(int32_t command_id) : command_id_(command_id) {} + +Status Request::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(client->WriteByte(0)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(8)); + TF_RETURN_IF_ERROR(client->WriteInt(command_id_)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(24)); + + return Status::OK(); +} + +Status 
Response::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->Ignore(1)); + TF_RETURN_IF_ERROR(client->SkipToPos(8)); + TF_RETURN_IF_ERROR(client->ReadInt(&req_id)); + TF_RETURN_IF_ERROR(client->SkipToPos(24)); + TF_RETURN_IF_ERROR(client->ReadInt(&res_type)); + + bool has_error; + TF_RETURN_IF_ERROR(client->ReadBool(&has_error)); + + if (has_error) { + int32_t error_code; + string error_msg; + TF_RETURN_IF_ERROR(client->ReadString(&error_msg)); + TF_RETURN_IF_ERROR(client->ReadInt(&error_code)); + + return errors::Unknown("Error [code=", error_code, ", message=\"", + error_msg, "\"]"); + } + + TF_RETURN_IF_ERROR(client->SkipToPos(header_size_ + 5)); + TF_RETURN_IF_ERROR(client->ReadInt(&length)); + TF_RETURN_IF_ERROR(client->SkipToPos(header_size_ + response_header_size_)); + + return Status::OK(); +} + +PathCtrlRequest::PathCtrlRequest(int32_t command_id_, const string &user_name, + const string &path, + const string &destination_path, bool flag, + bool collocate, + const std::map &properties) + : Request(command_id_), + user_name_(user_name), + path_(path), + destination_path_(destination_path), + flag_(flag), + collocate_(collocate), + props_(properties) {} + +Status PathCtrlRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(Request::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteString(user_name_)); + TF_RETURN_IF_ERROR(WritePath(client, path_)); + TF_RETURN_IF_ERROR(WritePath(client, destination_path_)); + TF_RETURN_IF_ERROR(client->WriteBool(flag_)); + TF_RETURN_IF_ERROR(client->WriteBool(collocate_)); + TF_RETURN_IF_ERROR(client->WriteStringMap(props_)); + + return Status::OK(); +} + +Status PathCtrlRequest::WritePath(ExtendedTCPClient *client, + const string &path) const { + TF_RETURN_IF_ERROR(client->WriteBool(!path.empty())); + if (!path.empty()) TF_RETURN_IF_ERROR(client->WriteString(path)); + + return Status::OK(); +} + +Status StreamCtrlRequest::Write(ExtendedTCPClient *client) const { + 
TF_RETURN_IF_ERROR(client->WriteByte(0)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(8)); + TF_RETURN_IF_ERROR(client->WriteInt(command_id_)); + TF_RETURN_IF_ERROR(client->WriteLong(stream_id_)); + TF_RETURN_IF_ERROR(client->WriteInt(length_)); + + return Status::OK(); +} + +StreamCtrlRequest::StreamCtrlRequest(int32_t command_id_, int64_t stream_id, + int32_t length) + : Request(command_id_), stream_id_(stream_id), length_(length) {} + +DeleteRequest::DeleteRequest(const string &user_name, const string &path, + bool flag) + : PathCtrlRequest(DELETE_ID, user_name, path, {}, flag, true, {}) {} + +Status DeleteResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&exists)); + + return Status::OK(); +} + +ExistsRequest::ExistsRequest(const string &user_name, const string &path) + : PathCtrlRequest(EXISTS_ID, user_name, path, {}, false, true, {}) {} + +Status ExistsResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&exists)); + + return Status::OK(); +} + +HandshakeRequest::HandshakeRequest(const string &fs_name, const string &log_dir) + : Request(HANDSHAKE_ID), fs_name_(fs_name), log_dir_(log_dir) {} + +Status HandshakeRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(Request::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteString(fs_name_)); + TF_RETURN_IF_ERROR(client->WriteString(log_dir_)); + + return Status::OK(); +} + +Status HandshakeResponse::Read(ExtendedTCPClient *client) { + int64_t block_size; + bool sampling; + + TF_RETURN_IF_ERROR(client->ReadNullableString(&fs_name)); + TF_RETURN_IF_ERROR(client->ReadLong(&block_size)); + + bool has_sampling_; + TF_RETURN_IF_ERROR(client->ReadBool(&has_sampling_)); + + if (has_sampling_) { + TF_RETURN_IF_ERROR(client->ReadBool(&sampling)); + } + + return Status::OK(); +} + +ListRequest::ListRequest(int32_t command_id_, const string &user_name, + const string &path) + : PathCtrlRequest(command_id_, user_name, path, {}, false, true, 
{}) {} + +ListFilesRequest::ListFilesRequest(const string &user_name, const string &path) + : ListRequest(LIST_FILES_ID, user_name, path) {} + +ListPathsRequest::ListPathsRequest(const string &user_name, const string &path) + : ListRequest(LIST_PATHS_ID, user_name, path) {} + +OpenCreateRequest::OpenCreateRequest(const string &user_name, + const string &path) + : PathCtrlRequest(OPEN_CREATE_ID, user_name, path, {}, false, true, {}) {} + +Status OpenCreateRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteInt(replication_)); + TF_RETURN_IF_ERROR(client->WriteLong(blockSize_)); + + return Status::OK(); +} + +Status OpenCreateResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + + return Status::OK(); +} + +OpenAppendRequest::OpenAppendRequest(const string &user_name, + const string &path) + : PathCtrlRequest(OPEN_APPEND_ID, user_name, path, {}, false, true, {}) {} + +Status OpenAppendRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + return Status::OK(); +} + +Status OpenAppendResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + + return Status::OK(); +} + +OpenReadRequest::OpenReadRequest(const string &user_name, const string &path, + bool flag, + int32_t sequential_reads_before_prefetch) + : PathCtrlRequest(OPEN_READ_ID, user_name, path, {}, flag, true, {}), + sequential_reads_before_prefetch_(sequential_reads_before_prefetch) {} + +OpenReadRequest::OpenReadRequest(const string &user_name, const string &path) + : OpenReadRequest(user_name, path, false, 0) {} + +Status OpenReadRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + if (flag_) { + TF_RETURN_IF_ERROR(client->WriteInt(sequential_reads_before_prefetch_)); + } + + return Status::OK(); +} + +Status 
OpenReadResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + TF_RETURN_IF_ERROR(client->ReadLong(&length)); + + return Status::OK(); +} + +InfoRequest::InfoRequest(const string &user_name, const string &path) + : PathCtrlRequest(INFO_ID, user_name, path, {}, false, true, {}) {} + +Status InfoResponse::Read(ExtendedTCPClient *client) { + file_info = IGFSFile(); + TF_RETURN_IF_ERROR(file_info.Read(client)); + + return Status::OK(); +} + +MakeDirectoriesRequest::MakeDirectoriesRequest(const string &user_name, + const string &path) + : PathCtrlRequest(MKDIR_ID, user_name, path, {}, false, true, {}) {} + +Status MakeDirectoriesResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +CloseRequest::CloseRequest(int64_t streamId) + : StreamCtrlRequest(CLOSE_ID, streamId, 0) {} + +Status CloseResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +ReadBlockRequest::ReadBlockRequest(int64_t stream_id, int64_t pos, + int32_t length) + : StreamCtrlRequest(READ_BLOCK_ID, stream_id, length), pos(pos) {} + +Status ReadBlockRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(StreamCtrlRequest::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteLong(pos)); + + return Status::OK(); +} + +Status ReadBlockResponse::Read(ExtendedTCPClient *client, int32_t length, + uint8_t *dst) { + TF_RETURN_IF_ERROR(client->ReadData(dst, length)); + successfully_read = length; + + return Status::OK(); +} + +Status ReadBlockResponse::Read(ExtendedTCPClient *client) { + return Status::OK(); +} + +std::streamsize ReadBlockResponse::GetSuccessfullyRead() { + return successfully_read; +} + +ReadBlockCtrlResponse::ReadBlockCtrlResponse(uint8_t *dst) + : CtrlResponse(false), dst(dst) {} + +Status ReadBlockCtrlResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(Response::Read(client)); + + 
res = ReadBlockResponse(); + TF_RETURN_IF_ERROR(res.Read(client, length, dst)); + + return Status::OK(); +} + +WriteBlockRequest::WriteBlockRequest(int64_t stream_id, const uint8_t *data, + int32_t length) + : StreamCtrlRequest(WRITE_BLOCK_ID, stream_id, length), data(data) {} + +Status WriteBlockRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(StreamCtrlRequest::Write(client)); + TF_RETURN_IF_ERROR(client->WriteData((uint8_t *)data, length_)); + + return Status::OK(); +} + +RenameRequest::RenameRequest(const string &user_name, const string &path, + const string &destination_path) + : PathCtrlRequest(RENAME_ID, user_name, path, destination_path, false, true, + {}) {} + +Status RenameResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h new file mode 100644 index 0000000000000000000000000000000000000000..44a2928a2b2b48849c7ba4454e0e7848c2217b3b --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h @@ -0,0 +1,356 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h" + +namespace tensorflow { + +enum CommandId { + HANDSHAKE_ID = 0, + EXISTS_ID = 2, + INFO_ID = 3, + RENAME_ID = 6, + DELETE_ID = 7, + MKDIR_ID = 8, + LIST_PATHS_ID = 9, + LIST_FILES_ID = 10, + OPEN_READ_ID = 13, + OPEN_APPEND_ID = 14, + OPEN_CREATE_ID = 15, + CLOSE_ID = 16, + READ_BLOCK_ID = 17, + WRITE_BLOCK_ID = 18, +}; + +class IGFSPath { + public: + Status Read(ExtendedTCPClient *client); + + string path; +}; + +class IGFSFile { + public: + Status Read(ExtendedTCPClient *client); + + int64_t length; + int64_t modification_time; + uint8_t flags; +}; + +class Request { + public: + Request(int32_t command_id); + virtual Status Write(ExtendedTCPClient *client) const; + + protected: + const int32_t command_id_; +}; + +class Response { + public: + virtual Status Read(ExtendedTCPClient *client); + + int32_t res_type; + int32_t req_id; + int32_t length; + + protected: + static const int32_t header_size_ = 24; + static const int32_t response_header_size_ = 9; +}; + +class PathCtrlRequest : public Request { + public: + PathCtrlRequest(int32_t command_id, const string &user_name, + const string &path, const string &destination_path, bool flag, + bool collocate, const std::map &properties); + Status Write(ExtendedTCPClient *client) const override; + + protected: + Status WritePath(ExtendedTCPClient *client, const string &path) const; + + const string user_name_; + const string path_; + const string destination_path_; + const bool flag_; + const bool collocate_; + const std::map props_; +}; + +class StreamCtrlRequest : public Request { + public: + StreamCtrlRequest(int32_t command_id, int64_t stream_id, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + 
protected: + int64_t stream_id_; + int32_t length_; +}; + +template +class CtrlResponse : public Response { + public: + CtrlResponse(bool optional) : optional_(optional) {} + Status Read(ExtendedTCPClient *client) override { + TF_RETURN_IF_ERROR(Response::Read(client)); + + if (optional_) { + TF_RETURN_IF_ERROR(client->ReadBool(&has_content)); + + if (!has_content) return Status::OK(); + } + + res = R(); + has_content = true; + TF_RETURN_IF_ERROR(res.Read(client)); + + return Status::OK(); + } + + R res; + bool has_content; + + private: + bool optional_; +}; + +template +class ListResponse { + public: + Status Read(ExtendedTCPClient *client) { + int32_t len; + TF_RETURN_IF_ERROR(client->ReadInt(&len)); + + entries.clear(); + + for (int32_t i = 0; i < len; i++) { + T f = {}; + TF_RETURN_IF_ERROR(f.Read(client)); + entries.push_back(f); + } + + return Status::OK(); + } + + std::vector entries; +}; + +class DeleteRequest : public PathCtrlRequest { + public: + DeleteRequest(const string &user_name, const string &path, bool flag); +}; + +class DeleteResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool exists; +}; + +class ExistsRequest : public PathCtrlRequest { + public: + explicit ExistsRequest(const string &user_name, const string &path); +}; + +class ExistsResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool exists; +}; + +class HandshakeRequest : public Request { + public: + HandshakeRequest(const string &fs_name, const string &log_dir); + Status Write(ExtendedTCPClient *client) const override; + + private: + string fs_name_; + string log_dir_; +}; + +class HandshakeResponse { + public: + Status Read(ExtendedTCPClient *client); + + string fs_name; +}; + +class ListRequest : public PathCtrlRequest { + public: + explicit ListRequest(int32_t command_id, const string &user_name, + const string &path); +}; + +class ListFilesRequest : public ListRequest { + public: + ListFilesRequest(const string &user_name, const string &path); +}; 
+ +class ListFilesResponse : public ListResponse {}; + +class ListPathsRequest : public ListRequest { + public: + ListPathsRequest(const string &user_name, const string &path); +}; + +class ListPathsResponse : public ListResponse {}; + +class OpenCreateRequest : public PathCtrlRequest { + public: + OpenCreateRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; + + private: + int32_t replication_; + int64_t blockSize_; +}; + +class OpenCreateResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; +}; + +class OpenAppendRequest : public PathCtrlRequest { + public: + explicit OpenAppendRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; +}; + +class OpenAppendResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; +}; + +class OpenReadRequest : public PathCtrlRequest { + public: + OpenReadRequest(const string &user_name, const string &path, bool flag, + int32_t seqReadsBeforePrefetch); + OpenReadRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; + + protected: + /** Sequential reads before prefetch. */ + int32_t sequential_reads_before_prefetch_; +}; + +class OpenReadResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; + int64_t length; +}; + +class InfoRequest : public PathCtrlRequest { + public: + InfoRequest(const string &user_name, const string &path); +}; + +class InfoResponse { + public: + Status Read(ExtendedTCPClient *client); + + IGFSFile file_info; +}; + +class MakeDirectoriesRequest : public PathCtrlRequest { + public: + MakeDirectoriesRequest(const string &userName, const string &path); +}; + +class MakeDirectoriesResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +/** Stream control requests. 
**/ + +class CloseRequest : public StreamCtrlRequest { + public: + explicit CloseRequest(int64_t stream_id); +}; + +class CloseResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +class ReadBlockRequest : public StreamCtrlRequest { + public: + ReadBlockRequest(int64_t stream_id, int64_t pos, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + private: + int64_t pos; +}; + +class ReadBlockResponse { + public: + Status Read(ExtendedTCPClient *client, int32_t length, uint8_t *dst); + Status Read(ExtendedTCPClient *client); + std::streamsize GetSuccessfullyRead(); + + private: + int32_t length; + std::streamsize successfully_read; +}; + +class ReadBlockCtrlResponse : public CtrlResponse { + public: + ReadBlockCtrlResponse(uint8_t *dst); + Status Read(ExtendedTCPClient *client) override; + + private: + uint8_t *dst; +}; + +class WriteBlockRequest : public StreamCtrlRequest { + public: + WriteBlockRequest(int64_t stream_id, const uint8_t *data, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + private: + const uint8_t *data; +}; + +class RenameRequest : public PathCtrlRequest { + public: + RenameRequest(const string &user_name, const string &path, + const string &destination_path); +}; + +class RenameResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4c898f14e6d298e65f563f4493a822172c40851 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +IGFSRandomAccessFile::IGFSRandomAccessFile(const string &file_name, + int64_t resource_id, + std::unique_ptr &&client) + : file_name_(file_name), + resource_id_(resource_id), + client_(std::move(client)) {} + +IGFSRandomAccessFile::~IGFSRandomAccessFile() { + CtrlResponse close_response = {false}; + Status status = client_->Close(&close_response, resource_id_); + + if (!status.ok()) LOG(ERROR) << status.ToString(); +} + +Status IGFSRandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, + char *scratch) const { + ReadBlockCtrlResponse response = ReadBlockCtrlResponse((uint8_t *)scratch); + TF_RETURN_IF_ERROR(client_->ReadBlock(&response, resource_id_, offset, n)); + + std::streamsize sz = response.res.GetSuccessfullyRead(); + if (sz == 0) return errors::OutOfRange("End of file"); + + *result = StringPiece(scratch, sz); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h new file mode 100644 index 0000000000000000000000000000000000000000..b21369ff8a3b19774bcc743f93a5ec4ae1c9b49a --- /dev/null +++ 
b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFSRandomAccessFile : public RandomAccessFile { + public: + IGFSRandomAccessFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client); + ~IGFSRandomAccessFile() override; + Status Read(uint64 offset, size_t n, StringPiece *result, + char *scratch) const override; + + private: + const string file_name_; + const int64_t resource_id_; + std::unique_ptr client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc new file mode 100644 index 0000000000000000000000000000000000000000..c15ecb7deeb0cf5a8a040e0d1e4b70c732729474 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +IGFSWritableFile::IGFSWritableFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client) + : file_name_(file_name), + resource_id_(resource_id), + client_(std::move(client)) {} + +IGFSWritableFile::~IGFSWritableFile() { + if (resource_id_ >= 0) { + CtrlResponse close_response = {false}; + + Status status = client_->Close(&close_response, resource_id_); + if (!status.ok()) LOG(ERROR) << status.ToString(); + } +} + +Status IGFSWritableFile::Append(StringPiece data) { + return client_->WriteBlock(resource_id_, (uint8_t *)data.data(), data.size()); +} + +Status IGFSWritableFile::Close() { + int64_t resource_to_be_closed = resource_id_; + resource_id_ = -1; + + CtrlResponse close_response = {false}; + return client_->Close(&close_response, resource_to_be_closed); +} + +Status IGFSWritableFile::Flush() { return Sync(); } + +Status IGFSWritableFile::Sync() { + CtrlResponse close_response = {false}; + TF_RETURN_IF_ERROR(client_->Close(&close_response, resource_id_)); + + CtrlResponse open_append_resp(false); + TF_RETURN_IF_ERROR(client_->OpenAppend(&open_append_resp, file_name_)); + + resource_id_ = open_append_resp.res.stream_id; + + return Status::OK(); +} + +} // namespace 
tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h new file mode 100644 index 0000000000000000000000000000000000000000..b406db17e0e350e2cef610bb05c40f658e100140 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFSWritableFile : public WritableFile { + public: + IGFSWritableFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client); + ~IGFSWritableFile() override; + Status Append(StringPiece data) override; + Status Close() override; + Status Flush() override; + Status Sync() override; + + private: + const string file_name_; + int64_t resource_id_; + std::unique_ptr client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ diff --git a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc b/tensorflow/contrib/ignite/ops/igfs_ops.cc similarity index 62% rename from 
tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc rename to tensorflow/contrib/ignite/ops/igfs_ops.cc index f3b24b2341e590adfbeac1a18b6a65fbfd34f598..473bddff08b339d3b76a33d40fe34486acdbe151 100644 --- a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc +++ b/tensorflow/contrib/ignite/ops/igfs_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 Google Inc. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" +#include "tensorflow/core/platform/env.h" -namespace tensorflow { -namespace fuzzing { +#include "tensorflow/contrib/ignite/kernels/igfs/igfs.h" -class FuzzDecodeJpeg : public FuzzStringInputOp { - SINGLE_INPUT_OP_BUILDER(DT_STRING, DecodeJpeg); -}; +namespace tensorflow { -STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeJpeg); +REGISTER_FILE_SYSTEM("igfs", IGFS); -} // end namespace fuzzing -} // end namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py b/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1d6707d6400a7cd84016150d20973809aca20e --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading IGFS ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_ignite_ops.so")) diff --git a/tensorflow/contrib/ignite/python/ops/igfs_ops.py b/tensorflow/contrib/ignite/python/ops/igfs_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..12b973b707730f6ba5b057b74a46b27d8f973ede --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/igfs_ops.py @@ -0,0 +1,40 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ignite File System for checkpointing and communication with TensorBoard. 
+ +Apache Ignite is a memory-centric distributed database, caching, and +processing platform for transactional, analytical, and streaming workloads, +delivering in-memory speeds at petabyte scale. In addition to database +functionality Apache Ignite provides a distributed file system called +IGFS (https://ignite.apache.org/features/igfs.html). IGFS delivers a similar +functionality to Hadoop HDFS, but only in-memory. In fact, in addition to +its own APIs, IGFS implements Hadoop FileSystem API and can be transparently +plugged into Hadoop or Spark deployments. This contrib package contains an +integration between IGFS and TensorFlow. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + +file_system_library = os.path.join(resource_loader.get_data_files_path(), + "../../_ignite_ops.so") +load_library.load_file_system_library(file_system_library) diff --git a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py index c9af7386cf0a26ed1a950130aa36caa7fb831fd0..e450e2d84ba31a7de925fdb78fc972a592c6ad8c 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py @@ -21,4 +21,4 @@ from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader _dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) + resource_loader.get_path_to_datafile("../../_ignite_ops.so")) diff --git a/tensorflow/contrib/signal/python/__init__.py b/tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh old mode 100644 new mode 100755 similarity index 73% rename from 
tensorflow/contrib/signal/python/__init__.py rename to tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh index e672d1146c53a813613c9076c0cb6056f7081441..5e39e16c05290f6b5786421670c69a3bd1e27add --- a/tensorflow/contrib/signal/python/__init__.py +++ b/tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Signal ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-igfs.xml & +sleep 5 # Wait Apache Ignite to be started + +tail -f nohup.out diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml new file mode 100644 index 0000000000000000000000000000000000000000..5d81bf33226cad0d5cc0ea1fb5c5b55672494976 --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 127.0.0.1 + + + + + + + + + diff --git a/tensorflow/contrib/ignite/python/tests/igfs_test.py b/tensorflow/contrib/ignite/python/tests/igfs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cacfc568942e20200b7daf10599dde513a4a0a68 --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/igfs_test.py @@ -0,0 +1,215 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for IGFS.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.contrib.ignite.python.ops.igfs_ops # pylint: disable=unused-import +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + + +class IGFSTest(test.TestCase): + """The Apache Ignite servers have to setup before the test and tear down + + after the test manually. The docker engine has to be installed. + + To setup Apache Ignite servers: + $ bash start_ignite.sh + + To tear down Apache Ignite servers: + $ bash stop_ignite.sh + """ + + def test_create_file(self): + """Test create file. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_create_file/1" + self.assertFalse(gfile.Exists(file_name)) + # Create file. + with gfile.Open(file_name, mode="w") as w: + w.write("") + # Check that file was created. + self.assertTrue(gfile.Exists(file_name)) + + def test_write_read_file(self): + """Test write/read file. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_write_read_file/1" + rows = 10000 + self.assertFalse(gfile.Exists(file_name)) + # Write data. + with gfile.Open(file_name, mode="w") as w: + for i in range(rows): + w.write("This is row\n") + # Read data. 
+ with gfile.Open(file_name, mode="r") as r: + lines = r.readlines() + # Check that data is equal. + self.assertEqual(rows, len(lines)) + for i in range(rows): + self.assertEqual("This is row\n", lines[i]) + + def test_delete_recursively(self): + """Test delete recursively. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_delete_recursively/" + file_name = "igfs:///test_delete_recursively/1" + self.assertFalse(gfile.Exists(dir_name)) + self.assertFalse(gfile.Exists(file_name)) + gfile.MkDir(dir_name) + with gfile.Open(file_name, mode="w") as w: + w.write("") + self.assertTrue(gfile.Exists(dir_name)) + self.assertTrue(gfile.Exists(file_name)) + # Delete directory recursively. + gfile.DeleteRecursively(dir_name) + # Check that directory was deleted. + self.assertFalse(gfile.Exists(dir_name)) + self.assertFalse(gfile.Exists(file_name)) + + def test_copy(self): + """Test copy. + + """ + # Setup and check preconditions. + src_file_name = "igfs:///test_copy/1" + dst_file_name = "igfs:///test_copy/2" + self.assertFalse(gfile.Exists(src_file_name)) + self.assertFalse(gfile.Exists(dst_file_name)) + with gfile.Open(src_file_name, mode="w") as w: + w.write("42") + self.assertTrue(gfile.Exists(src_file_name)) + self.assertFalse(gfile.Exists(dst_file_name)) + # Copy file. + gfile.Copy(src_file_name, dst_file_name) + # Check that files are identical. + self.assertTrue(gfile.Exists(src_file_name)) + self.assertTrue(gfile.Exists(dst_file_name)) + with gfile.Open(dst_file_name, mode="r") as r: + data = r.read() + self.assertEqual("42", data) + + def test_is_directory(self): + """Test is directory. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_is_directory/1" + file_name = "igfs:///test_is_directory/2" + with gfile.Open(file_name, mode="w") as w: + w.write("") + gfile.MkDir(dir_name) + # Check that directory is a directory. + self.assertTrue(gfile.IsDirectory(dir_name)) + # Check that file is not a directory. 
+ self.assertFalse(gfile.IsDirectory(file_name)) + + def test_list_directory(self): + """Test list directory. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_list_directory/" + file_names = [ + "igfs:///test_list_directory/1", "igfs:///test_list_directory/2/3" + ] + ch_dir_names = [ + "igfs:///test_list_directory/4", + ] + for file_name in file_names: + with gfile.Open(file_name, mode="w") as w: + w.write("") + for ch_dir_name in ch_dir_names: + gfile.MkDir(ch_dir_name) + ls_expected_result = file_names + ch_dir_names + # Get list of files in directory. + ls_result = gfile.ListDirectory(dir_name) + # Check that list of files is correct. + self.assertEqual(len(ls_expected_result), len(ls_result)) + for e in ["1", "2", "4"]: + self.assertTrue(e in ls_result) + + def test_make_dirs(self): + """Test make dirs. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_make_dirs/" + self.assertFalse(gfile.Exists(dir_name)) + # Make directory. + gfile.MkDir(dir_name) + # Check that directory was created. + self.assertTrue(gfile.Exists(dir_name)) + + def test_remove(self): + """Test remove. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_remove/1" + self.assertFalse(gfile.Exists(file_name)) + with gfile.Open(file_name, mode="w") as w: + w.write("") + self.assertTrue(gfile.Exists(file_name)) + # Remove file. + gfile.Remove(file_name) + # Check that file was removed. + self.assertFalse(gfile.Exists(file_name)) + + def test_rename_file(self): + """Test rename file. + + """ + # Setup and check preconditions. + src_file_name = "igfs:///test_rename_file/1" + dst_file_name = "igfs:///test_rename_file/2" + with gfile.Open(src_file_name, mode="w") as w: + w.write("42") + self.assertTrue(gfile.Exists(src_file_name)) + # Rename file. + gfile.Rename(src_file_name, dst_file_name) + # Check that only new name of file is available. 
+ self.assertFalse(gfile.Exists(src_file_name)) + self.assertTrue(gfile.Exists(dst_file_name)) + with gfile.Open(dst_file_name, mode="r") as r: + data = r.read() + self.assertEqual("42", data) + + def test_rename_dir(self): + """Test rename dir. + + """ + # Setup and check preconditions. + src_dir_name = "igfs:///test_rename_dir/1" + dst_dir_name = "igfs:///test_rename_dir/2" + gfile.MkDir(src_dir_name) + # Rename directory. + gfile.Rename(src_dir_name, dst_dir_name) + # Check that only new name of directory is available. + self.assertFalse(gfile.Exists(src_dir_name)) + self.assertTrue(gfile.Exists(dst_dir_name)) + self.assertTrue(gfile.IsDirectory(dst_dir_name)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh index a67bd44f2fb0d654ba07f022a5070c68df8e2ede..112e0dea844620de600e277bff3685dd7c42c49c 100755 --- a/tensorflow/contrib/ignite/python/tests/start_ignite.sh +++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh @@ -20,3 +20,7 @@ SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )" # Start Apache Ignite with plain client listener. docker run -itd --name ignite-plain -p 42300:10800 \ -v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh + +# Start Apache Ignite with IGFS. 
+docker run -itd --name ignite-igfs -p 10500:10500 \ +-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-igfs.sh \ No newline at end of file diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh index 8f03dbd1ede61f548d3de9d9738f97667e75df3c..35b0f32d1b3e1373a231ff23f2b40c8ccc417baf 100755 --- a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh +++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh @@ -15,5 +15,4 @@ # ============================================================================== docker rm -f ignite-plain -docker rm -f ignite-ssl -docker rm -f ignite-ssl-auth +docker rm -f ignite-igfs \ No newline at end of file diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index 478b716d88321101c971789f36c0ff8ecd3f418e..108da04494685f06f9afc26a26a5dadcdd99b0ff 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -115,7 +115,7 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, &tranformation_matrix]( + [&input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { // Applying projection matrix to input RGB vectors. 
const float* p = input_data.data() + start_channel * kChannelSize; diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py index 24b790977dfdb675ff7bf0a119a08e243a30d3aa..ae9c7a611945e1445c933d74b9944054b3f0e0a4 100644 --- a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py @@ -24,7 +24,7 @@ from tensorflow.contrib.image.python.ops import dense_image_warp from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes - +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients @@ -259,7 +259,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase): shape = [1, 2, 1, 1] msg = 'Should have raised an exception for invalid image size' - with self.assertRaises(ValueError, msg=msg): + with self.assertRaises(errors.InvalidArgumentError, msg=msg): self.check_interpolation_correctness(shape, 'float32', 'float32') diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py index 9c7ada7afb7cb620c2e06887795053778f133287..f7ced440720209cb05dfcd79395c51517f9de0d5 100644 --- a/tensorflow/contrib/image/python/ops/dense_image_warp.py +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops @@ -60,28 +61,38 @@ def _interpolate_bilinear(grid, msg = 'Grid must be 4 dimensional. 
Received size: ' raise ValueError(msg + str(grid.get_shape())) - batch_size, height, width, channels = shape + batch_size, height, width, channels = (array_ops.shape(grid)[0], + array_ops.shape(grid)[1], + array_ops.shape(grid)[2], + array_ops.shape(grid)[3]) + + shape = [batch_size, height, width, channels] query_type = query_points.dtype grid_type = grid.dtype - if (query_points.shape.rank != 3 or - query_points.shape.dims[2].value != 2): - msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received ' - 'size: ') - raise ValueError(msg + str(query_points.get_shape())) - - _, num_queries, _ = query_points.get_shape().as_list() - - if height < 2 or width < 2: - msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' - raise ValueError(msg + str(grid.get_shape())) - - alphas = [] - floors = [] - ceils = [] - - index_order = [0, 1] if indexing == 'ij' else [1, 0] - unstacked_query_points = array_ops.unstack(query_points, axis=2) + with ops.control_dependencies([ + check_ops.assert_equal( + len(query_points.get_shape()), + 3, + message='Query points must be 3 dimensional.'), + check_ops.assert_equal( + array_ops.shape(query_points)[2], + 2, + message='Query points must be size 2 in dim 2.') + ]): + num_queries = array_ops.shape(query_points)[1] + + with ops.control_dependencies([ + check_ops.assert_greater_equal( + height, 2, message='Grid height must be at least 2.'), + check_ops.assert_greater_equal( + width, 2, message='Grid width must be at least 2.') + ]): + alphas = [] + floors = [] + ceils = [] + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) for dim in index_order: with ops.name_scope('dim-' + str(dim)): @@ -112,16 +123,18 @@ def _interpolate_bilinear(grid, alpha = array_ops.expand_dims(alpha, 2) alphas.append(alpha) - if batch_size * height * width > np.iinfo(np.int32).max / 8: - error_msg = """The image size or batch size is sufficiently large - that the 
linearized addresses used by array_ops.gather - may exceed the int32 limit.""" - raise ValueError(error_msg) - - flattened_grid = array_ops.reshape(grid, - [batch_size * height * width, channels]) - batch_offsets = array_ops.reshape( - math_ops.range(batch_size) * height * width, [batch_size, 1]) + with ops.control_dependencies([ + check_ops.assert_less_equal( + math_ops.cast(batch_size * height * width, dtype=dtypes.float32), + np.iinfo(np.int32).max / 8, + message="""The image size or batch size is sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""") + ]): + flattened_grid = array_ops.reshape( + grid, [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) # This wraps array_ops.gather. We reshape the image data such that the # batch, y, and x coordinates are pulled into the first dimension. @@ -182,7 +195,11 @@ def dense_image_warp(image, flow, name='dense_image_warp'): of dimensions. """ with ops.name_scope(name): - batch_size, height, width, channels = image.get_shape().as_list() + batch_size, height, width, channels = (array_ops.shape(image)[0], + array_ops.shape(image)[1], + array_ops.shape(image)[2], + array_ops.shape(image)[3]) + # The flow is defined on the image grid. Turn the flow into a list of query # points in the grid space. grid_x, grid_y = array_ops.meshgrid( diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index 3327a9f9a613bfb56e6a25af0fe1c0ca18609035..9e19884df852c0fd259a55aef56c62b4189cd1da 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,7 +20,7 @@ from __future__ import print_function # Generic layers. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index c8812d4b23f94102d093db878a709b090a3318d6..588f15b867c1fedbadd5a5d945d870a356549468 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -70,7 +70,10 @@ py_test( "python/ops/core_test.py", ], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_windows", # TODO: needs investigation on Windows + "noasan", # TODO(b/119323169) + ], deps = [ ":_typecheck", ":core", diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index e6596bfdfb9b153e5946ab7f8933c160cf2f2326..795591ea621dd192e203d4c4c680aebed961f690 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -253,7 +253,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -277,7 +277,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index af8e673f5906ad972408d30f23f2e8ba7e031a00..32f3006b749e3b34572a8d642054c0ec4c4664b0 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -14,10 +14,6 @@ # 
============================================================================== """Ops for building neural network layers, regularizers, summaries, etc. -See the -[Contrib Layers](https://tensorflow.org/api_guides/python/contrib.layers) -guide. - @@avg_pool2d @@avg_pool3d @@batch_norm diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index 124515e5a6474f2cc1038830346e27277c6ceea7..8015a571e14d0024b0beca700936c21f705b5752 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -691,7 +691,6 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): index += num_val return grouped_vals - @test_util.enable_c_shapes def testEmbeddingLookupSparse(self): vocab_size = 13 batch_size = 10 diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 6fb4b9ff3534cab34c84de5d13fea7aff756556d..7e6eafaa0d6f60cfc28a4c422abac0b6d5a991fb 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -27,7 +27,7 @@ from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index 
d90d6ecf7f671a40a7ff2b066b6782c7421f9887..cab8da808b6413518ff4864cb0b03a42809260f1 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -27,7 +27,7 @@ import numpy as np from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.layers.python.layers import feature_column_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index ac9561c7693fc4ad994a00889aa3f15b4b5a5ee4..403b522ce45ac6ad98a321378626b87aaa7738aa 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -35,6 +35,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base from tensorflow.python.layers import convolutional as convolutional_layers from tensorflow.python.layers import core as core_layers @@ -1958,7 +1959,7 @@ class GDN(base.Layer): self._reparam_offset = reparam_offset self.data_format = data_format self._channel_axis() # trigger ValueError early - self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5) + self.input_spec = input_spec.InputSpec(min_ndim=3, max_ndim=5) def _channel_axis(self): try: @@ -2015,7 +2016,7 @@ class GDN(base.Layer): raise ValueError('The channel dimension of the inputs to `GDN` ' 'must be defined.') self._input_rank = input_shape.ndims - 
self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( ndim=input_shape.ndims, axes={ channel_axis: num_channels }) diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 28a6f5aed99b1443ebcc9c391ec332e0febbb04b..7bf2ac62d76d67f0eb131f8f57c5c063955424fa 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -19,9 +19,6 @@ This module and all its submodules are deprecated. See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) for migration instructions. -See the [Contrib Learn](https://tensorflow.org/api_guides/python/contrib.learn) -guide. - @@BaseEstimator @@Estimator @@Trainable diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index eabebb7e881558471c343c0573cc9a8f4a425312..18ca4214a1c407653294ecfac0116bf00cda46a1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -28,7 +28,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -38,11 +37,12 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.ops import 
nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.training import training_util # The default learning rate of 0.05 is a historical artifact of the initial # implementation, but seems a reasonable choice. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 3d85533d92d17095bae9a69f229171e1bf61ba10..7a3cc8bd984b1b621f50d9dbf2979dcd6fa8b11f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -38,7 +38,7 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 4e65c180d8bee9ab8fe9b1fbf32edc229c31af09..d46a873bfaa297e7f6242aa56e9d0bf0eb551867 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -36,7 +36,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from 
tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 2bd57597c2e9444b51b1dacfbe4180b443c95a3d..ee25cebd484f1e831fe8b6d3aa7290da7558adee 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index e100bc7a1e7be4896e9ab1c965775b5185b38897..439b17e505d1146492a32cc2fd58febee2b2456d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -37,7 +37,7 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.contrib.linear_optimizer.python import sdca_optimizer -from 
tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 597ca4e86dbf66c86182f14a2a364b662d52fb0a..dfc76bfde6c0109f98093232b6f223d6938007f9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -37,7 +37,7 @@ from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec from tensorflow.contrib.linear_optimizer.python import sdca_optimizer as sdca_optimizer_lib from tensorflow.contrib.metrics.python.ops import metric_ops -from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_lib as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor @@ -1745,7 +1745,7 @@ class LinearRegressorTest(test.TestCase): 'place_holder': constant_op.constant([[0.0]] * num_examples), }, constant_op.constant( - [[1 if i % 4 is 0 else 0] for i in range(num_examples)]) + [[1 if i % 4 == 0 else 0] for i in range(num_examples)]) place_holder = feature_column_lib.real_valued_column('place_holder') sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index 647667188238dc18b137eaad98356a79b3a549b4..7a5354222f103aa0f45adc513079e420bbbfd30c 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ 
b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -524,7 +524,7 @@ class SDCALinearRegressorTest(test.TestCase): # LinearClassifier requires at least one column. 'place_holder': constant_op.constant([[0.0]] * num_examples), - }, constant_op.constant([[1 if i % 4 is 0 else 0] + }, constant_op.constant([[1 if i % 4 == 0 else 0] for i in range(num_examples)]) with self._single_threaded_test_session(): diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 619294b51822bd9983eda777acae5cf0d253926d..d8ac4163b21ce9accceb35f68cf13b0d6b093f9c 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -22,7 +22,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import add_arg_scope -from tensorflow.python.compat import compat from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -67,34 +66,6 @@ def _scale_losses(losses, weights): return math_ops.reduce_sum(reduced_losses) -def _safe_div(numerator, denominator, name="value"): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - name: An optional name for the returned op. - - Returns: - The element-wise value of the numerator divided by the denominator. 
- """ - if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator, name=name) - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator), - name=name) - - def _safe_mean(losses, num_present): """Computes a safe mean of the losses. @@ -107,7 +78,7 @@ def _safe_mean(losses, num_present): then zero is returned. """ total_loss = math_ops.reduce_sum(losses) - return _safe_div(total_loss, num_present, name="value") + return math_ops.div_no_nan(total_loss, num_present, name="value") @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.") @@ -612,14 +583,14 @@ def mean_pairwise_squared_error(predictions, math_ops.square(diffs), reduction_indices=reduction_indices) num_present_per_batch = _num_present(diffs, weights, per_batch=True) - term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch, - name="value") + term1 = 2.0 * math_ops.div_no_nan( + sum_squares_diff_per_batch, num_present_per_batch, name="value") sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch), - name="value") + term2 = 2.0 * math_ops.div_no_nan( + math_ops.square(sum_diff), + math_ops.square(num_present_per_batch), + name="value") loss = _scale_losses(term1 - term2, weights) diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 6c3b02e12b3082be8bfcc316c4c6122931eb5f76..1293e59cbcba86115e99b505b1f0672a01526462 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -142,7 +142,7 @@ First, download and install JetPack for Android version 3.2 or greater from [Nvi git clone https://github.com/tensorflow/tensorflow.git cd tensorflow JETPACK=$HOME/JetPack_Android_3.2 
-TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda-9.0/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so" +TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so" ``` #### Building all CUDA-enabled native binaries: diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 0a07588f07f0bb89dbf5dc5909f511f74470fb41..b396c527673902d61072dc9cf7d2766476be8369 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -34,7 +34,7 @@ NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\. # 1.10 branch does not work. `make distclean` fails and blocks the build # process. For now we're hardcoding to the version which is used by # TensorFlow 1.9. -PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" +PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz" # TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.bazel.build. 
RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index eab93f2cc5ed3d5179a58fa717d8b83d0c4d7337..655c7eefcb978d40c8bc16a23685e03ed71bfb63 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -42,6 +42,7 @@ tensorflow/core/kernels/conv_grad_filter_ops.cc tensorflow/core/kernels/conv_grad_input_ops.cc tensorflow/core/kernels/conv_grad_ops.cc tensorflow/core/kernels/conv_ops.cc +tensorflow/core/kernels/conv_ops_3d.cc tensorflow/core/kernels/conv_ops_fused.cc tensorflow/core/kernels/conv_ops_using_gemm.cc tensorflow/core/kernels/crop_and_resize_op.cc @@ -156,6 +157,7 @@ tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc +tensorflow/core/kernels/multinomial_op.cc tensorflow/core/kernels/no_op.cc tensorflow/core/kernels/non_max_suppression_op.cc tensorflow/core/kernels/one_hot_op.cc @@ -163,6 +165,7 @@ tensorflow/core/kernels/pack_op.cc tensorflow/core/kernels/pad_op.cc tensorflow/core/kernels/padding_fifo_queue.cc tensorflow/core/kernels/padding_fifo_queue_op.cc +tensorflow/core/kernels/pooling_ops_3d.cc tensorflow/core/kernels/pooling_ops_common.cc tensorflow/core/kernels/population_count_op.cc tensorflow/core/kernels/quantization_utils.cc @@ -248,7 +251,9 @@ tensorflow/core/kernels/spectrogram_op.cc tensorflow/core/kernels/split_lib_cpu.cc tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc +tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc +tensorflow/core/kernels/stateless_random_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc tensorflow/core/kernels/strided_slice_op_inst_1.cc diff --git 
a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index ac1236086503a7c6e541bdf098efcb92f84e577f..062deb74b165329d8e72efa73b9d81f4174f8831 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -175,7 +175,7 @@ def f1_score(labels, predictions, weights=None, num_thresholds=200, return best_f1 best_f1 = distribution_strategy_context.get_replica_context().merge_call( - f1_across_replicas, values) + f1_across_replicas, args=(values,)) update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], fn=update_ops['fn'], name='update') diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index d6932f6e4b603b1a76250ab622f5fe8eaea81bc9..09fe65b73f8f866a02a5f0c4d7d736973782882a 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -24,7 +24,6 @@ from __future__ import print_function import collections as collections_lib -from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -46,32 +45,6 @@ from tensorflow.python.util.deprecation import deprecated _EPSILON = 1e-7 -def _safe_div(numerator, denominator): - """Computes a safe divide which returns 0 if the denominator is zero. - - Note that the function contains an additional conditional check that is - necessary for avoiding situations where the loss is zero causing NaNs to - creep into the gradient computation. - - Args: - numerator: An arbitrary `Tensor`. - denominator: A `Tensor` whose shape matches `numerator` and whose values are - assumed to be non-negative. - - Returns: - The element-wise value of the numerator divided by the denominator. 
- """ - if compat.forward_compatible(2018, 11, 1): - return math_ops.div_no_nan(numerator, denominator) - return array_ops.where( - math_ops.greater(denominator, 0), - math_ops.div(numerator, - array_ops.where( - math_ops.equal(denominator, 0), - array_ops.ones_like(denominator), denominator)), - array_ops.zeros_like(numerator)) - - @deprecated(None, 'Please switch to tf.metrics.true_positives. Note that the ' 'order of the labels and predictions arguments has been switched.') def streaming_true_positives(predictions, @@ -3247,24 +3220,20 @@ def streaming_covariance(predictions, # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) # batch_mean_prediction is E[x_B] in the update equation - batch_mean_prediction = _safe_div( - math_ops.reduce_sum(weighted_predictions), - batch_count) - delta_mean_prediction = _safe_div( - (batch_mean_prediction - mean_prediction) * batch_count, - update_count) + batch_mean_prediction = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_predictions), batch_count) + delta_mean_prediction = math_ops.div_no_nan( + (batch_mean_prediction - mean_prediction) * batch_count, update_count) update_mean_prediction = state_ops.assign_add(mean_prediction, delta_mean_prediction) # prev_mean_prediction is E[x_A] in the update equation prev_mean_prediction = update_mean_prediction - delta_mean_prediction # batch_mean_label is E[y_B] in the update equation - batch_mean_label = _safe_div( - math_ops.reduce_sum(weighted_labels), - batch_count) - delta_mean_label = _safe_div( - (batch_mean_label - mean_label) * batch_count, - update_count) + batch_mean_label = math_ops.div_no_nan( + math_ops.reduce_sum(weighted_labels), batch_count) + delta_mean_label = math_ops.div_no_nan( + (batch_mean_label - mean_label) * batch_count, update_count) update_mean_label = state_ops.assign_add(mean_label, delta_mean_label) # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label @@ -3926,9 +3895,8 @@ def 
cohen_kappa(labels, po_sum = math_ops.reduce_sum(po) total = math_ops.reduce_sum(pe_row) pe_sum = math_ops.reduce_sum( - _safe_div( - math_ops.to_double(pe_row * pe_col), - math_ops.to_double(total))) + math_ops.div_no_nan( + math_ops.to_double(pe_row * pe_col), math_ops.to_double(total))) po_sum, pe_sum, total = (math_ops.to_double(po_sum), math_ops.to_double(pe_sum), math_ops.to_double(total)) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index b313024e2852caf2385454771b289ad0162cc463..45a60d79482787df4564ae3360f8252af93c7a26 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -51,7 +51,7 @@ The pruning library allows for specification of the following hyper parameters: | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | -| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | +| threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. 
| | block_height|integer | 1 | Number of rows in a block for block sparse matrices| diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py index f0ce6fe03966c2de2dfd8ebcca07bf46afcf4fce..1fa5c8cb485704a5fccc486e823bbc4050bf505a 100644 --- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -119,7 +120,7 @@ class _MaskedConv(base.Layer): self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer - self.input_spec = base.InputSpec(ndim=self.rank + 2) + self.input_spec = input_spec.InputSpec(ndim=self.rank + 2) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) @@ -171,7 +172,7 @@ class _MaskedConv(base.Layer): dtype=self.dtype) else: self.bias = None - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( ndim=self.rank + 2, axes={channel_axis: input_dim}) self.built = True @@ -393,14 +394,14 @@ class MaskedFullyConnected(base.Layer): self.bias_initializer = bias_initializer self.kernel_regularizer = kernel_regularizer self.bias_regularizer = bias_regularizer - self.input_spec = base.InputSpec(min_ndim=2) + self.input_spec = input_spec.InputSpec(min_ndim=2) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if tensor_shape.dimension_value(input_shape[-1]) is None: raise ValueError('The last dimension of the inputs to `Dense` ' 'should be defined. 
Found `None`.') - self.input_spec = base.InputSpec( + self.input_spec = input_spec.InputSpec( min_ndim=2, axes={-1: tensor_shape.dimension_value(input_shape[-1])}) self.kernel = self.add_variable( diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index d2b811641764df05c66654dfcb044fa7e78853a5..f6b4373edd0544555dd16a373802d2feb5d674b1 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -204,7 +204,7 @@ def get_pruning_hparams(): begin_pruning_step=0, end_pruning_step=-1, weight_sparsity_map=[''], - threshold_decay=0.9, + threshold_decay=0.0, pruning_frequency=10, nbins=256, block_height=1, @@ -456,13 +456,14 @@ class Pruning(object): pool_window = [self._block_dim[0], self._block_dim[1]] pool_fn = pruning_utils.factorized_pool - + squeeze_axis = None if not self._spec.use_tpu: pool_fn = nn_ops.pool abs_weights = array_ops.reshape( abs_weights, [1, abs_weights.get_shape()[0], abs_weights.get_shape()[1], 1]) + squeeze_axis = [0, 3] pooled_weights = pool_fn( abs_weights, @@ -473,7 +474,7 @@ class Pruning(object): name=weights.op.name + '_pooled') if pooled_weights.get_shape().ndims != 2: - pooled_weights = array_ops.squeeze(pooled_weights) + pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis) smoothed_threshold, new_mask = self._update_mask(pooled_weights, threshold) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 91b0bb7f6003c047e4dcf342695f433edbc11614..14fc51229ab53a77e8089040e8a8576babd0fafd 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -188,7 +188,6 @@ def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: values = 
ops.convert_to_tensor(values, name='values') values = array_ops.reshape(values, [-1]) - value_range = ops.convert_to_tensor(value_range, name='value_range') nbins_float = np.float32(nbins) # Map tensor values that fall within value_range to [0, 1]. @@ -250,7 +249,6 @@ def compute_cdf(values, value_range, **kwargs): name = kwargs.get('name', None) with ops.name_scope(name, 'cdf', [values, value_range, nbins]): values = ops.convert_to_tensor(values, name='values') - value_range = ops.convert_to_tensor(value_range, name='value_range') nbins_float = np.float32(nbins) # Map tensor values that fall within value_range to [0, 1]. @@ -336,7 +334,7 @@ def factorized_pool(input_tensor, padding=padding) return array_ops.squeeze( - array_ops.transpose(width_pooling, perm=[0, 1, 3, 2])) + array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]), axis=[0, 1]) def determine_partitioned_axis(partitioned_variable): diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index 0aca843497611552d922715514118cac003c29b2..d6f2bfcb6c2e2beda912eb538d8a4a0a17b486b3 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -85,8 +85,28 @@ class PruningUtilsTest(test.TestCase): @parameterized.named_parameters( - ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]), - ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1])) + ("Input_32x32_block_1x1", [32, 32], [1, 1]), + # block size 6x6 + ("Input_3x3_block_6x6", [3, 3], [6, 6]), + ("Input_32x32_block_6x6", [32, 32], [6, 6]), + ("Input_2x32_block_6x6", [2, 32], [6, 6]), + ("Input_32x2_block_6x6", [32, 2], [6, 6]), + ("Input_30x30_block_6x6", [30, 30], [6, 6]), + # block size 4x4 + ("Input_32x32_block_4x4", [32, 32], [4, 4]), + ("Input_2x32_block_4x4", [2, 32], [4, 4]), + ("Input_32x2_block_4x4", [32, 2], [4, 4]), + ("Input_30x30_block_4x4", [30, 30], [4, 4]), + # block 
size 1x4 + ("Input_32x32_block_1x4", [32, 32], [1, 4]), + ("Input_2x32_block_1x4", [2, 32], [1, 4]), + ("Input_32x2_block_1x4", [32, 2], [1, 4]), + ("Input_30x30_block_1x4", [30, 30], [1, 4]), + # block size 4x1 + ("Input_32x32_block_4x1", [32, 32], [4, 1]), + ("Input_2x32_block_4x1", [2, 32], [4, 1]), + ("Input_32x2_block_4x1", [32, 2], [4, 1]), + ("Input_30x30_block_4x1", [30, 30], [4, 1])) class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): def _compare_pooling_methods(self, weights, pooling_kwargs): @@ -97,9 +117,11 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): array_ops.reshape( weights, [1, weights.get_shape()[0], - weights.get_shape()[1], 1]), **pooling_kwargs)) + weights.get_shape()[1], 1]), **pooling_kwargs), + axis=[0, 3]) pooled_weights_factorized_pool = pruning_utils.factorized_pool( weights, **pooling_kwargs) + self.assertAllClose(pooled_weights_tf.eval(), pooled_weights_factorized_pool.eval()) @@ -113,8 +135,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): [expanded_tensor, kronecker_product]) self.assertAllEqual(expanded_tensor_val, kronecker_product_val) - def testFactorizedAvgPool(self, window_shape): - weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + def testFactorizedAvgPool(self, input_shape, window_shape): + weights = variable_scope.get_variable("weights", shape=input_shape) pooling_kwargs = { "window_shape": window_shape, "pooling_type": "AVG", @@ -123,8 +145,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): } self._compare_pooling_methods(weights, pooling_kwargs) - def testFactorizedMaxPool(self, window_shape): - weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + def testFactorizedMaxPool(self, input_shape, window_shape): + weights = variable_scope.get_variable("weights", shape=input_shape) pooling_kwargs = { "window_shape": window_shape, "pooling_type": "MAX", @@ -133,8 +155,8 @@ 
class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): } self._compare_pooling_methods(weights, pooling_kwargs) - def testExpandTensor(self, block_dim): - weights = random_ops.random_normal(shape=[1024, 512]) + def testExpandTensor(self, input_shape, block_dim): + weights = random_ops.random_normal(shape=input_shape) self._compare_expand_tensor_with_kronecker_product(weights, block_dim) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index 9ce50bfe1054072b315adecb87f1ba729dfe0d83..b7fd2d2fb9db3eed15eb1cc2934199939790b1c0 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -106,6 +106,32 @@ class MovingAverageOptimizer(optimizer.Optimizer): self._swapped_variable_name_map[v_avg.op.name] = v.op.name return control_flow_ops.group(train_op, ma_op, name='train_with_avg') + def _find_swapped_variable(self, v_name_to_tensor, v_name, tensor): + """Returns name of swapped variable for given tensor. + + Args: + v_name_to_tensor: Mapping from variable names to tensors. + v_name: name of the variable for which swapped variable should be returned + tensor: Tensor which correspond to variable for which swapped variable + should be returned. + + Returns: + Tensor which correspond to swapped variable. + + Raises: + ValueError: If swapped variable could not be found in v_name_to_tensor. + """ + swapped_v_name = self._swapped_variable_name_map.get(v_name, None) + if swapped_v_name is None: + return tensor + else: + if swapped_v_name in v_name_to_tensor: + return v_name_to_tensor[swapped_v_name] + else: + raise ValueError( + ('Variable to swap %s is not part of variables to save. 
' + 'This breaks MovingAverageOptimizer.') % swapped_v_name) + def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): """Create a saver swapping moving averages and variables. @@ -141,33 +167,33 @@ class MovingAverageOptimizer(optimizer.Optimizer): if not isinstance(var_list, dict): var_list = saver.BaseSaverBuilder.OpListToDict(var_list) - # OpListToDict converts variables to tensors. We make sure we can get - # the unique variable name for normal and resource vaiables. - def get_v_name(tensor): - if tensor.op.type == 'ReadVariableOp': - return tensor.op.inputs[0].op.name - else: - return tensor.op.name - v_name_to_tensor = {} - for tensor in six.itervalues(var_list): - v_name = get_v_name(tensor) - v_name_to_tensor[v_name] = tensor + for k, tensor_or_list in six.iteritems(var_list): + # For each partitioned variable OpListToDict returns list of constituent + # parts instead of single tensor. + if (isinstance(tensor_or_list, list) + or isinstance(tensor_or_list, variables.PartitionedVariable)): + for tensor in tensor_or_list: + v_name = tensor.op.name + v_name_to_tensor[v_name] = tensor + else: + v_name_to_tensor[k] = tensor_or_list # Now swap variables and moving averages swapped_var_list = {} - for k, tensor in six.iteritems(var_list): - v_name = get_v_name(tensor) - swapped_v_name = self._swapped_variable_name_map.get(v_name, None) - tensor_to_save = tensor - if swapped_v_name is not None: - if swapped_v_name in v_name_to_tensor: - tensor_to_save = v_name_to_tensor[swapped_v_name] - else: - raise ValueError( - ('Variable to swap %s is not part of variables to save. 
' - 'This breaks MovingAverageOptimizer.') % swapped_v_name) - swapped_var_list[k] = tensor_to_save + for k, tensor_or_list in six.iteritems(var_list): + if isinstance(tensor_or_list, list): + tensor_list_to_save = [] + for tensor in tensor_or_list: + v_name = tensor.op.name + swapped_variable = self._find_swapped_variable(v_name_to_tensor, + v_name, + tensor) + tensor_list_to_save.append(swapped_variable) + swapped_var_list[k] = tensor_list_to_save + else: + swapped_var_list[k] = self._find_swapped_variable( + v_name_to_tensor, k, tensor_or_list) # Build the swapping saver. return saver.Saver(swapped_var_list, name=name, **kwargs) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index f22e7245285a8b2716645f9789eb5997928a22d2..643403eea6f88bcb33aa96d6539bc9a45a109c6b 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -43,97 +45,171 @@ class MovingAverageOptimizerTest(test.TestCase): # Test that MovingAverageOptimizer works with resource variables. self._helpTestRun(use_resource=True) - def _helpTestRun(self, use_resource=False): + def testRunUsePartitionedVars(self): + # Test that MovingAverageOptimizer works with partitioned variables. 
+ self._helpTestRun(use_partitioned_vars=True) + + def testRunUseResourcePartitionedVars(self): + # Test that MovingAverageOptimizer works with resource and partitioned + # variables. + self._helpTestRun(use_partitioned_vars=True, use_resource=True) + + def _helpTestRun(self, use_resource=False, use_partitioned_vars=False): + # Partitioned variables are represented as a "collection" of partitions. + # To simplify the test and reuse as much code as possible we employ + # following test strategy for partitioned variables. + # + # In the case of non-partitioned variables test runs on variables with + # shape [2]. + # + # In the case of partitioned variables we use shape [4] with two partitions, + # thus each partition has shape [2]. + # For partitioned variables the test is run twice (for loop over + # variable_part_names), first time on the first partition of each variable, + # second time on the second partition of each variable. + variable_part_names = ['part_0', 'part_1'] if use_partitioned_vars else [''] for sequential_update in [True, False]: for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.session(graph=ops.Graph()) as sess: - orig_val0 = [1.0, 2.0] - orig_val1 = [3.0, 4.0] - var0 = variable_scope.get_variable( - 'var0', - initializer=constant_op.constant(orig_val0, dtype=dtype), - use_resource=use_resource) - var1 = variable_scope.get_variable( - 'var1', - initializer=constant_op.constant(orig_val1, dtype=dtype), - use_resource=use_resource) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - - opt = moving_average_optimizer.MovingAverageOptimizer( - gradient_descent.GradientDescentOptimizer(learning_rate=2.0), - average_decay=0.5, - sequential_update=sequential_update) - save_dir = tempfile.mkdtemp( - prefix=os.path.join(self.get_temp_dir(), 'run_1')) - save_path = os.path.join(save_dir, 'model') - update = opt.apply_gradients( - list(six.moves.zip([grads0, grads1], 
[var0, var1]))) - global_vars = variables.global_variables() - ema_var0 = [ - v for v in global_vars - if v.op.name == 'var0/ExponentialMovingAverage' - ][0] - ema_var1 = [ - v for v in global_vars - if v.op.name == 'var1/ExponentialMovingAverage' - ][0] - perturb = control_flow_ops.group([ - state_ops.assign_add(var0, [1.0, 1.0]), - state_ops.assign_add(var1, [2.0, 2.0]), - state_ops.assign_add(ema_var0, [3.0, 3.0]), - state_ops.assign_add(ema_var1, [4.0, 4.0]) - ]) - - # Test that saver with missing ema variables will fail. - with self.assertRaisesRegexp(ValueError, r'Variable to swap'): - opt.swapping_saver(var_list=[var0]) - - train_saver = opt.swapping_saver() - train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) - inference_saver = saver.Saver() - variables.global_variables_initializer().run() - # Step 1. - update.run() - self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) - if sequential_update: - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) - # Test that the swapping saver save/restore operation is identity. - train_saver.save(sess, save_path) - train_saver.restore(sess, save_path) - self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) - if sequential_update: - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) - # Test that the subset saver saves the EMA variable as well. 
- if sequential_update: - subset_save_path = save_path + '_subset' - train_saver_subset.save(sess, subset_save_path) - perturb.run() - self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) - self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) - self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) - # Restoring should only restore var0 and ema_var0. - train_saver_subset.restore(sess, subset_save_path) + for var_part_name in variable_part_names: + with self.session(graph=ops.Graph()) as sess: + orig_val0 = [1.0, 2.0] + orig_val1 = [3.0, 4.0] + grads0 = [0.1, 0.1] + grads1 = [0.01, 0.01] + if use_partitioned_vars: + # Use partitioned variables. + # Create partitioned and duplicate each value used as initial + # value of variables. + partitioner = partitioned_variables.fixed_size_partitioner( + num_shards=2) + orig_val0 = orig_val0 * 2 + orig_val1 = orig_val1 * 2 + grads0 = grads0 * 2 + grads1 = grads1 * 2 + else: + # Regular (non-partitioned) variables. 
+ partitioner = None + var0 = variable_scope.get_variable( + 'var0', + initializer=constant_op.constant(orig_val0, dtype=dtype), + use_resource=use_resource, + partitioner=partitioner) + var1 = variable_scope.get_variable( + 'var1', + initializer=constant_op.constant(orig_val1, dtype=dtype), + use_resource=use_resource, + partitioner=partitioner) + # Make a fake loss, such that gradient(loss, var0) == grads0 + # and gradient(loss, var1) == grads1 + grads0 = constant_op.constant(grads0, dtype=dtype) + grads1 = constant_op.constant(grads1, dtype=dtype) + loss = (math_ops.reduce_sum(grads0 * var0) + + math_ops.reduce_sum(grads1 * var1)) + + opt = moving_average_optimizer.MovingAverageOptimizer( + gradient_descent.GradientDescentOptimizer(learning_rate=2.0), + average_decay=0.5, + sequential_update=sequential_update) + save_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'run_1')) + save_path = os.path.join(save_dir, 'model') + + update = opt.minimize(loss) + + # Get variables and their EMAs. In case of partitioned variables + # get proper part of each variable. + def _get_variable(var_name, part_name, ema): + """Returns variable of it's moving average by name.""" + matches = [ + v for v in variables.global_variables() + if ((var_name in v.op.name) + and (part_name in v.op.name) + and (('ExponentialMovingAverage' in v.op.name) == ema)) + ] + self.assertEqual(len(matches), 1) + return matches[0] + var0 = _get_variable('var0', var_part_name, ema=False) + var1 = _get_variable('var1', var_part_name, ema=False) + ema_var0 = _get_variable('var0', var_part_name, ema=True) + ema_var1 = _get_variable('var1', var_part_name, ema=True) + + perturb = control_flow_ops.group([ + state_ops.assign_add(var0, [1.0, 1.0]), + state_ops.assign_add(var1, [2.0, 2.0]), + state_ops.assign_add(ema_var0, [3.0, 3.0]), + state_ops.assign_add(ema_var1, [4.0, 4.0]) + ]) + + # Test that saver with missing ema variables will fail. 
+ with self.assertRaisesRegexp(ValueError, r'Variable to swap'): + opt.swapping_saver(var_list=[var0]) + + train_saver = opt.swapping_saver() + train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) + inference_saver = saver.Saver() + variables.global_variables_initializer().run() + # Step 1. + update.run() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) - self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) - # Restore back to previous state. + self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the swapping saver save/restore operation is identity. + train_saver.save(sess, save_path) train_saver.restore(sess, save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the subset saver saves the EMA variable as well. + if sequential_update: + subset_save_path = save_path + '_subset' + train_saver_subset.save(sess, subset_save_path) + perturb.run() + self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) + self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restoring should only restore var0 and ema_var0. 
+ train_saver_subset.restore(sess, subset_save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restore back to previous state. + train_saver.restore(sess, save_path) - # If updates are parallel, this is not always true after the 1st step. - if sequential_update: + # If updates are parallel, + # this is not always true after the 1st step. + if sequential_update: + # Test that the normal saver will have the averaged variables. + # We test that the average values are between the original value + # and the most recent variable values (since they are an average + # of the two). + val0 = var0.eval() + val1 = var1.eval() + train_saver.save(sess, save_path) + inference_saver.restore(sess, save_path) + avg_val0 = var0.eval() + avg_val1 = var1.eval() + for i in six.moves.range(len(val0)): + self.assertLess(val0[i], avg_val0[i]) + self.assertLess(avg_val0[i], orig_val0[i]) + self.assertLess(val1[i], avg_val1[i]) + self.assertLess(avg_val1[i], orig_val1[i]) + train_saver.restore(sess, save_path) + # Step 2. + update.run() # Test that the normal saver will have the averaged variables. - # We test that the average values are between the original value - # and the most recent variable values (since they are an average - # of the two). + # We test that the average values are between the original value and + # the most recent variable values (since they are an average of the + # two). 
val0 = var0.eval() val1 = var1.eval() + self.assertAllCloseAccordingToType([0.6, 1.6], val0) + self.assertAllCloseAccordingToType([2.96, 3.96], val1) train_saver.save(sess, save_path) inference_saver.restore(sess, save_path) avg_val0 = var0.eval() @@ -143,26 +219,6 @@ class MovingAverageOptimizerTest(test.TestCase): self.assertLess(avg_val0[i], orig_val0[i]) self.assertLess(val1[i], avg_val1[i]) self.assertLess(avg_val1[i], orig_val1[i]) - train_saver.restore(sess, save_path) - # Step 2. - update.run() - # Test that the normal saver will have the averaged variables. - # We test that the average values are between the original value and - # the most recent variable values (since they are an average of the - # two). - val0 = var0.eval() - val1 = var1.eval() - self.assertAllCloseAccordingToType([0.6, 1.6], val0) - self.assertAllCloseAccordingToType([2.96, 3.96], val1) - train_saver.save(sess, save_path) - inference_saver.restore(sess, save_path) - avg_val0 = var0.eval() - avg_val1 = var1.eval() - for i in six.moves.range(len(val0)): - self.assertLess(val0[i], avg_val0[i]) - self.assertLess(avg_val0[i], orig_val0[i]) - self.assertLess(val1[i], avg_val1[i]) - self.assertLess(avg_val1[i], orig_val1[i]) def testFailWhenSaverCreatedBeforeInitialized(self): with self.cached_session(): diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index 44a8890cb107440b79cf8fbbdfcfda503b1c910f..960826407b66b4efa3c2693efb6d2e17c4b47b33 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -83,14 +84,14 @@ class NadamOptimizer(adam.AdamOptimizer): with ops.control_dependencies([m_t]): m_t = scatter_add(m, indices, m_scaled_g_values) # m_bar = (1 - beta1) * g_t + beta1 * m_t - m_bar = m_scaled_g_values + beta1_t * m_t + m_bar = m_scaled_g_values + beta1_t * array_ops.gather(m_t, indices) # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) v = self.get_slot(var, "v") v_scaled_g_values = (grad * grad) * (1 - beta2_t) v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) with ops.control_dependencies([v_t]): v_t = scatter_add(v, indices, v_scaled_g_values) - v_sqrt = math_ops.sqrt(v_t) - var_update = state_ops.assign_sub( - var, lr * m_bar / (v_sqrt + epsilon_t), use_locking=self._use_locking) + v_t_slice = array_ops.gather(v_t, indices) + v_sqrt = math_ops.sqrt(v_t_slice) + var_update = scatter_add(var, indices, -lr * m_bar / (v_sqrt + epsilon_t)) return control_flow_ops.group(*[var_update, m_bar, v_t]) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index 85e05ce71cec6ef897cadb7d123e630febb3c064..a4372f64874e7591dbceac901fad6c941209bef9 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -52,14 +52,19 @@ def nadam_update_numpy(param, class NadamOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): + # need to use a larger value of epsilon here so that + # np.sqrt(v_t) + epsilon doesn't get rounded to 0 when + # the dtype is half and np.sqrt(v_t) = 0, as is the case + # when the gradient is 0 + sparse_epsilon = 1e-7 for dtype 
in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0, 0.01], dtype=dtype.as_numpy_dtype) if use_resource: var0 = resource_variable_ops.ResourceVariable(var0_np) @@ -67,21 +72,21 @@ class NadamOptimizerTest(test.TestCase): else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) - grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0_np_indices = np.array([0, 2], dtype=np.int32) grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([2])) - grads1_np_indices = np.array([0, 1], dtype=np.int32) + constant_op.constant(grads0_np[grads0_np_indices]), + constant_op.constant(grads0_np_indices), constant_op.constant([3])) + grads1_np_indices = np.array([0, 2], dtype=np.int32) grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([2])) - opt = nadam_optimizer.NadamOptimizer() + constant_op.constant(grads1_np[grads1_np_indices]), + constant_op.constant(grads1_np_indices), constant_op.constant([3])) + opt = nadam_optimizer.NadamOptimizer(epsilon=sparse_epsilon) update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + 
self.assertAllClose([1.0, 1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 3.0, 4.0], var1.eval()) beta1_power, beta2_power = opt._get_beta_accumulators() @@ -91,8 +96,10 @@ class NadamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) update.run() - var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0) - var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1) + var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0, + epsilon=sparse_epsilon) + var1_np, m1, v1 = nadam_update_numpy(var1_np, grads1_np, t, m1, v1, + epsilon=sparse_epsilon) # Validate updated params self.assertAllCloseAccordingToType(var0_np, var0.eval()) diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 3ba3ee29ec79687df522eb330665a2ce80061682..835fb4aec4f88572cb54d24ca2deae022e277c5c 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -56,6 +56,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:reduce_util", ], ) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f789c83e005ab7ad7e7caff4ef9ee3c2f57c21fe..a72db5e12fc086c3ec817d25d4964bbb9df2db60 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -24,6 +24,7 @@ import abc import six +from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import dtypes @@ -446,7 +447,7 @@ class _OptimizerV2State(object): if v is None: if colocate_with is None: colocate_with = self._non_slot_devices - with self._distribution.colocate_vars_with(colocate_with): + with self._distribution.extended.colocate_vars_with(colocate_with): # 
TODO(josh11b): Use get_variable() except for the legacy Adam use case. v = variable_scope.variable(initial_value, name=name, trainable=False) self._non_slot_dict[name] = v @@ -790,14 +791,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Scale loss for number of replicas (callable-loss case). In this case, # we have to be careful to call distribute_lib.get_loss_reduction() # *after* loss() is evaluated, so we know what loss reduction it uses. - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) - if scale_loss_by_num_replicas: - num_replicas = distribute_ctx.get_distribution_strategy().num_replicas - if num_replicas > 1: - loss_value *= 1. / num_replicas + loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas) if var_list is None: var_list = tape.watched_variables() @@ -808,14 +802,7 @@ class OptimizerV2(optimizer_v1.Optimizer): "be a function when eager execution is enabled.") # Scale loss for number of replicas (non-callable-loss case). - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) - if scale_loss_by_num_replicas: - num_replicas = distribute_ctx.get_distribution_strategy().num_replicas - if num_replicas > 1: - loss *= 1. 
/ num_replicas + loss = self._scale_loss(loss, scale_loss_by_num_replicas) if gate_gradients not in [ optimizer_v1.Optimizer.GATE_NONE, optimizer_v1.Optimizer.GATE_OP, @@ -857,6 +844,19 @@ class OptimizerV2(optimizer_v1.Optimizer): ]) return grads_and_vars + @staticmethod + def _scale_loss(loss_value, scale_loss_by_num_replicas): + """Scale loss for the number of replicas.""" + if scale_loss_by_num_replicas is None: + scale_loss_by_num_replicas = ( + distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN) + if scale_loss_by_num_replicas: + num_replicas = \ + distribute_ctx.get_distribution_strategy().num_replicas_in_sync + if num_replicas > 1: + loss_value *= 1. / num_replicas + return loss_value + def apply_gradients(self, grads_and_vars, global_step=None, name=None): """Apply gradients to variables. @@ -892,7 +892,8 @@ class OptimizerV2(optimizer_v1.Optimizer): raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) return distribute_ctx.get_replica_context().merge_call( - self._distributed_apply, filtered, global_step=global_step, name=name) + self._distributed_apply, args=(filtered,), + kwargs={"global_step": global_step, "name": name}) def _get_or_create_state(self, var_list=None): """Either looks up or creates `_OptimizerV2State`. 
@@ -927,8 +928,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" - reduced_grads = distribution.batch_reduce( - variable_scope.VariableAggregation.SUM, grads_and_vars) + reduced_grads = distribution.extended.batch_reduce_to( + ds_reduce_util.ReduceOp.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) @@ -944,7 +945,7 @@ class OptimizerV2(optimizer_v1.Optimizer): with ops.name_scope(name, self._name) as name: per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list) # Include the current value of any dynamic hyper parameters in `state`. - non_slot_devices = distribution.non_slot_devices(var_list) + non_slot_devices = distribution.extended.non_slot_devices(var_list) state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access self._hyper, distribution, non_slot_devices) @@ -989,7 +990,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # Use the processors to update the variables. update_ops = [] for grad, var in grads_and_vars: - update_ops.extend(distribution.update(var, update, grad, grouped=False)) + update_ops.extend(distribution.extended.update( + var, update, args=(grad,), group=False)) # Give the child class a chance to do something after applying # gradients @@ -1001,8 +1003,8 @@ class OptimizerV2(optimizer_v1.Optimizer): update_ops = control_flow_ops.group(update_ops) with ops.control_dependencies([update_ops]): - finish_updates = distribution.update_non_slot( - non_slot_devices, finish, grouped=False) + finish_updates = distribution.extended.update_non_slot( + non_slot_devices, finish, group=False) # We said grouped=False, which means finish_updates is always a list. # It will be [None] when finish() returns None. 
if finish_updates == [None]: @@ -1017,8 +1019,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def update_global_step(global_step, name): return global_step.assign_add(1, read_value=False, name=name) - apply_updates = distribution.update(global_step, update_global_step, - name) + apply_updates = distribution.extended.update( + global_step, update_global_step, args=(name,)) # Add the training op to the TRAIN_OP graph collection in graph mode. if not eager_execution: diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD index b3f32b8f34e7b956b44bc82322bba16ed6fe43c7..38fcca03116721f3dabfa6d1e7122c369b6b405d 100644 --- a/tensorflow/contrib/resampler/BUILD +++ b/tensorflow/contrib/resampler/BUILD @@ -50,6 +50,7 @@ tf_kernel_library( prefix = "resampler_ops", deps = [ ":resampler_ops_op_lib", + "//tensorflow/compiler/tf2xla/kernels:resampler_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", ], diff --git a/tensorflow/contrib/resampler/ops/resampler_ops.cc b/tensorflow/contrib/resampler/ops/resampler_ops.cc index 5ab212032e50ace9545762bebda5679f68fbf77c..f785d4ee5fcd63212882ccf736bfc61c35d68545 100644 --- a/tensorflow/contrib/resampler/ops/resampler_ops.cc +++ b/tensorflow/contrib/resampler/ops/resampler_ops.cc @@ -25,7 +25,7 @@ REGISTER_OP("Resampler") .Input("data: T") .Input("warp: T") .Output("output: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle data; ShapeHandle warp; @@ -48,7 +48,7 @@ REGISTER_OP("ResamplerGrad") .Input("grad_output: T") .Output("grad_data: T") .Output("grad_warp: T") - .Attr("T: {half, float, double}") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(0)); c->set_output(1, c->input(1)); diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 391df8cdb4b1c6cd0e22ff2e27527c58abd4c303..e124867415f94fb5052f34f50363ea718d71053b 100644 --- 
a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -196,6 +196,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/lstm_ops_test.py"], additional_deps = [ ":rnn_py", + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 026bf08ced33cf0d663cf0940e8bea3f3f2aca28..cbc8af5350276bf3398cf29a24554fd27e0621ee 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """RNN Cells and additional RNN operations. -See [Contrib RNN](https://tensorflow.org/api_guides/python/contrib.rnn) guide. - @@RNNCell @@LayerRNNCell diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.h b/tensorflow/contrib/rnn/kernels/blas_gemm.h index 9535a76566748eaf8b4756ad0dc26218262ed990..d37210d4b81203287fb633adc309688a35d093bb 100644 --- a/tensorflow/contrib/rnn/kernels/blas_gemm.h +++ b/tensorflow/contrib/rnn/kernels/blas_gemm.h @@ -32,15 +32,26 @@ struct TensorCuBlasGemm { const T* b, int ldb, float beta, T* c, int ldc); }; +template +struct gemm_compute_type { + typedef T type; +}; + +template <> +struct gemm_compute_type { + typedef float type; +}; + template struct TensorBlasGemm; template struct TensorBlasGemm { static void compute(OpKernelContext* ctx, const Device& d, bool transa, - bool transb, float alpha, + bool transb, typename gemm_compute_type::type alpha, typename TTypes::ConstMatrix a, - typename TTypes::ConstMatrix b, float beta, + typename TTypes::ConstMatrix b, + typename gemm_compute_type::type beta, typename TTypes::Matrix c) { int64 m = c.dimensions()[0]; int64 n = c.dimensions()[1]; @@ -55,19 +66,23 @@ struct TensorBlasGemm { template struct TensorBlasGemm { static void compute(OpKernelContext* ctx, const Device& d, bool transa, - bool transb, T alpha, typename 
TTypes::ConstMatrix a, - typename TTypes::ConstMatrix b, T beta, + bool transb, typename gemm_compute_type::type alpha, + typename TTypes::ConstMatrix a, + typename TTypes::ConstMatrix b, + typename gemm_compute_type::type beta, typename TTypes::Matrix c) { Eigen::array, 1> contract_pairs; contract_pairs[0] = Eigen::IndexPair(transa == false, transb == true); - if (alpha == T(1) && beta == T(0)) { + if (alpha == typename gemm_compute_type::type(1.f) && + beta == typename gemm_compute_type::type(0.f)) { c.device(d) = a.contract(b, contract_pairs); - } else if (alpha == T(1) && beta == T(1)) { + } else if (alpha == typename gemm_compute_type::type(1.f) && + beta == typename gemm_compute_type::type(1.f)) { c.device(d) += a.contract(b, contract_pairs); } else { - c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) + - c.constant(beta) * c; + c.device(d) = c.constant(T(alpha)) * a.contract(b, contract_pairs) + + c.constant(T(beta)) * c; } } }; diff --git a/tensorflow/contrib/rnn/kernels/gru_ops.h b/tensorflow/contrib/rnn/kernels/gru_ops.h index 3e2cb39e64bb3f0b22ea66c5601af36c5fb9b0fd..38be58fa104f8b30e4aede6d18330960fc30dcb5 100644 --- a/tensorflow/contrib/rnn/kernels/gru_ops.h +++ b/tensorflow/contrib/rnn/kernels/gru_ops.h @@ -88,7 +88,9 @@ struct GRUBlockCellFprop : public GRUCell { typename TTypes::ConstMatrix const_x_h_prev(x_h_prev.data(), x_h_prev.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, T(1), const_x_h_prev, w_ru, T(0), r_u_bar); + ctx, d, false, false, typename gemm_compute_type::type(1.f), + const_x_h_prev, w_ru, typename gemm_compute_type::type(0.f), + r_u_bar); // Creating a bias matrix for adding by broadcasting 'b_ru' Eigen::array broadcast_shape({batch_size_, 1}); @@ -107,7 +109,8 @@ struct GRUBlockCellFprop : public GRUCell { typename TTypes::ConstMatrix const_x_h_prevr(x_h_prevr.data(), x_h_prevr.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, T(1), const_x_h_prevr, w_c, T(0), c); + ctx, d, false, 
false, typename gemm_compute_type::type(1.f), + const_x_h_prevr, w_c, typename gemm_compute_type::type(0.f), c); Eigen::array b_c_shape({1, b_c.dimensions()[0]}); c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape)); @@ -148,9 +151,10 @@ struct GRUBlockCellBprop : public GRUCell { // [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T typename TTypes::ConstMatrix const_d_c_bar(d_c_bar.data(), d_c_bar.dimensions()); - TensorBlasGemm::compute(ctx, d, false, true, T(1), - const_d_c_bar, w_c, T(0), - d_x_comp2_and_h_prevr); + TensorBlasGemm::compute( + ctx, d, false, true, typename gemm_compute_type::type(1.f), + const_d_c_bar, w_c, typename gemm_compute_type::type(0.f), + d_x_comp2_and_h_prevr); d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends()); d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r); @@ -164,7 +168,8 @@ struct GRUBlockCellBprop : public GRUCell { typename TTypes::ConstMatrix const_d_r_bar_u_bar( d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions()); TensorBlasGemm::compute( - ctx, d, false, true, T(1), const_d_r_bar_u_bar, w_ru, T(0), + ctx, d, false, true, typename gemm_compute_type::type(1.f), + const_d_r_bar_u_bar, w_ru, typename gemm_compute_type::type(0.f), d_x_comp1_and_h_prev_comp1); // d_x = d_x_comp1 + d_x_comp2 diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index ee08d306f84baaba8b774ce3fa1a04d5f9a4f6dd..d369bc12ae88dafb4e3ca0095a08bcc3ee09bf70 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -61,7 +61,8 @@ void LSTMBlockCellFpropWithEigen( // states1 = xh * w + b typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + ctx, d, false, false, typename gemm_compute_type::type(1.f), const_xh, + w, typename gemm_compute_type::type(0.f), icfo); Eigen::array b_shape({1, b.dimensions()[0]}); Eigen::array 
broadcast_shape({cell.batch_size(), 1}); icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); @@ -87,11 +88,11 @@ void LSTMBlockCellFpropWithEigen( if (use_peephole) { auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias) + f_peep) + f.constant(T(forget_bias)) + f_peep) .sigmoid(); } else { f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + - f.constant(forget_bias)) + f.constant(T(forget_bias))) .sigmoid(); } @@ -100,7 +101,7 @@ void LSTMBlockCellFpropWithEigen( if (cell_clip > 0.0f) { cs.device(d) = - cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); + cs.binaryExpr(cs.constant(T(cell_clip)), Eigen::scalar_clip_op()); } // co = tanh(cs) @@ -225,6 +226,7 @@ void LSTMBlockCellBpropWithEigen( template struct LSTMBlockCellBprop; DEFINE_CPU_SPECS(float); +DEFINE_CPU_SPECS(Eigen::half); #undef DEFINE_CPU_SPECS } // namespace functor @@ -373,7 +375,7 @@ class LSTMBlockCellOp : public OpKernel { Name("LSTMBlockCell").Device(DEVICE_CPU).TypeConstraint("T"), \ LSTMBlockCellOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA @@ -398,7 +400,6 @@ namespace functor { DECLARE_GPU_SPEC(float); DECLARE_GPU_SPEC(Eigen::half); -// DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // end namespace functor @@ -661,7 +662,7 @@ class LSTMBlockCellGradOp : public OpKernel { Name("LSTMBlockCellGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ LSTMBlockCellGradOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA @@ -1008,7 +1009,7 @@ class BlockLSTMOp : public OpKernel { Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint("T"), \ BlockLSTMOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA @@ -1283,7 
+1284,7 @@ class BlockLSTMGradOp : public OpKernel { Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ BlockLSTMGradOp); REGISTER_KERNEL(float); -// REGISTER_KERNEL(double); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL #if GOOGLE_CUDA diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index b664b0f45ee08648e4dc10e8244340df1615ad19..15ae95f13cffa5d1469d737b23f2a83b9e5a694f 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -141,7 +141,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, // const int gid = batch_id * cell_size * 4 + act_id; const int cid = batch_id * cell_size + act_id; - Eigen::internal::scalar_sigmoid_op sigmoid_op; + Eigen::internal::scalar_logistic_op sigmoid_op; Eigen::internal::scalar_tanh_op tanh_op; Eigen::scalar_clip_op clip_op; @@ -169,7 +169,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, f[cid] = f_local; T cs_local = i_local * ci_local + f_local * cs_prev[cid]; - if (cell_clip_t > strict_cast(0.0f)) { + if (cell_clip > 0.0f) { cs_local = clip_op(cs_local, cell_clip_t); } cs[cid] = cs_local; @@ -248,7 +248,8 @@ void LSTMBlockCellFpropWithCUDA( // states1 = xh * w typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); TensorBlasGemm::compute( - ctx, d, false, false, 1.f, const_xh, w, 0.f, icfo); + ctx, d, false, false, typename gemm_compute_type::type(1.f), const_xh, + w, typename gemm_compute_type::type(0.f), icfo); // Add bias, apply non-linearities and gating. 
// diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 9ce0b399ba173b67285e907a050c71af5d57068c..d5700d2a200f6cdac06183366c0d11ec3531235b 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.rnn.python.kernel_tests import benchmarking @@ -27,6 +28,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_bitwise_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import rnn @@ -38,7 +41,70 @@ from tensorflow.python.platform import test block_lstm = lstm_ops._block_lstm # pylint: disable=protected-access -def blocks_match(sess, use_peephole): +class _MaskedRandomUniformInitializer(init_ops.RandomUniform): + """Initializer for uniform dist tensors with trailing bits zeroed-out. + + Allow returning tensors with last few mantissa bits set to 0. This potentially + helps avoid getting into precision issues when testing low precision (float16) + computation. + """ + + def __init__(self, + minval=0, + maxval=None, + seed=None, + dtype=dtypes.float16, + num_valid_mantissa_bits=4): + """Constructor. + + Args: + minval: A python scalar or a scalar tensor. Lower bound of the range of + random values to generate. + maxval: A python scalar or a scalar tensor. Upper bound of the range of + random values to generate. Defaults to 1 for float types. + seed: A Python integer. Used to create random seeds. 
See + `tf.set_random_seed` for behavior. + dtype: The data type. Only supports tf.float16 for now. + num_valid_mantissa_bits: number of non-zero mantissa bits, default to 4. + + Raises: + ValueError: An error if `dtype` is not tf.float16. + """ + if dtype not in (dtypes.float16,): + raise ValueError("dtype: %s not supported" % dtype.name) + + super(_MaskedRandomUniformInitializer, self).__init__( + minval=minval, maxval=maxval, seed=seed, dtype=dtype) + self._num_mantissa_bits = 10 + self._num_valid_mantissa_bits = num_valid_mantissa_bits + + def __call__(self, shape, dtype=dtypes.float16, partition_info=None): + if dtype and dtype != dtypes.float16: + raise ValueError("dtype: %s not supported" % dtype.name) + res = super(_MaskedRandomUniformInitializer, self).__call__( + shape, dtype, partition_info) + # get uint16 view of the underlying buffer. + res = gen_array_ops.bitcast(res, dtypes.uint16) + + # mask the last `shift` mantissa bits. + shift = self._num_mantissa_bits - self._num_valid_mantissa_bits + mask = (0xffff >> shift) << shift + res = gen_bitwise_ops.bitwise_and(res, mask) + + # restore float16 view. 
+ return gen_array_ops.bitcast(res, dtype) + + +def _get_initializer(init_bound, dtype, seed): + if dtype == dtypes.float16: + return _MaskedRandomUniformInitializer( + -init_bound, init_bound, dtype=dtype, seed=seed) + else: + return init_ops.random_uniform_initializer( + -init_bound, init_bound, dtype=dtype, seed=seed) + + +def blocks_match(sess, use_peephole, dtype=dtypes.float32, cell_clip=None): batch_size = 2 input_size = 3 cell_size = 4 @@ -47,36 +113,42 @@ def blocks_match(sess, use_peephole): inputs = [] for _ in range(sequence_length): inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) + np.random.randn(batch_size, input_size), dtype=dtype) inputs.append(inp) stacked_inputs = array_ops.stack(inputs) - initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212) + init_bound = 1e-1 if dtype == dtypes.float16 else 1e-2 + initializer = _get_initializer(init_bound, dtype=dtype, seed=19890212) with variable_scope.variable_scope("test", initializer=initializer): # magic naming so that the cells pick up these variables and reuse them if use_peephole: wci = variable_scope.get_variable( - "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32) + "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtype) wcf = variable_scope.get_variable( - "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32) + "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtype) wco = variable_scope.get_variable( - "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32) + "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtype) w = variable_scope.get_variable( "rnn/lstm_cell/kernel", shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) + dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[cell_size * 4], - dtype=dtypes.float32, + dtype=dtype, initializer=init_ops.zeros_initializer()) basic_cell = rnn_cell.LSTMCell( - cell_size, use_peepholes=use_peephole, 
state_is_tuple=True, reuse=True) + cell_size, + use_peepholes=use_peephole, + cell_clip=cell_clip, + dtype=dtype, + state_is_tuple=True, + reuse=True) basic_outputs_op, basic_state_op = rnn.static_rnn( - basic_cell, inputs, dtype=dtypes.float32) + basic_cell, inputs, dtype=dtype) if use_peephole: _, _, _, _, _, _, block_outputs_op = block_lstm( @@ -87,7 +159,7 @@ def blocks_match(sess, use_peephole): wci=wci, wcf=wcf, wco=wco, - cell_clip=0, + cell_clip=cell_clip, use_peephole=True) else: _, _, _, _, _, _, block_outputs_op = block_lstm( @@ -95,13 +167,15 @@ def blocks_match(sess, use_peephole): inputs, w, b, - cell_clip=0) + cell_clip=cell_clip) fused_cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=use_peephole, reuse=True, + cell_size, + cell_clip=cell_clip, + use_peephole=use_peephole, + reuse=True, name="rnn/lstm_cell") - fused_outputs_op, fused_state_op = fused_cell( - stacked_inputs, dtype=dtypes.float32) + fused_outputs_op, fused_state_op = fused_cell(stacked_inputs, dtype=dtype) sess.run([variables.global_variables_initializer()]) basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]]) @@ -127,7 +201,19 @@ def blocks_match(sess, use_peephole): block_wgrads, fused_wgrads) -class LSTMBlockCellTest(test.TestCase): +class LSTMBlockCellTest(test.TestCase, parameterized.TestCase): + + TEST_CASES = ({ + "testcase_name": "Fp32", + "dtype": dtypes.float32, + "rtol": 1e-6, + "atol": 1e-6 + }, { + "testcase_name": "Fp16", + "dtype": dtypes.float16, + "rtol": 8e-3, + "atol": 8e-4 + }) def testNoneDimsWithDynamicRNN(self): with self.session(use_gpu=True, graph=ops.Graph()) as sess: @@ -314,41 +400,43 @@ class LSTMBlockCellTest(test.TestCase): for basic, block in zip(basic_res, block_res): self.assertAllClose(basic, block) - def testLSTMBasicToBlock(self): - with self.session(use_gpu=True) as sess: + def LSTMBasicToBlockTestHelper(self, + dtype=dtypes.float32, + use_peephole=False, + cell_clip=None, + rtol=1e-6, + 
atol=1e-6): + with self.session(use_gpu=True, graph=ops.Graph()) as sess: (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, fused_wgrads) = blocks_match( - sess, use_peephole=False) + sess, use_peephole=use_peephole, dtype=dtype, cell_clip=cell_clip) - self.assertAllClose(basic_outputs, block_outputs) - self.assertAllClose(basic_grads, block_grads) + self.assertAllClose(basic_outputs, block_outputs, rtol=rtol, atol=atol) + self.assertAllClose(basic_grads, block_grads, rtol=rtol, atol=atol) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) + self.assertAllClose(basic, block, rtol=rtol, atol=atol) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(block_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) + self.assertAllClose(basic_outputs, fused_outputs, rtol=rtol, atol=atol) + self.assertAllClose(basic_state, fused_state, rtol=rtol, atol=atol) + self.assertAllClose(basic_grads, fused_grads, rtol=rtol, atol=atol) + for basic, fused in zip(basic_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=rtol, atol=atol) - def testLSTMBasicToBlockPeeping(self): - with self.session(use_gpu=True) as sess: - (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, - basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, - fused_wgrads) = blocks_match( - sess, use_peephole=True) + @parameterized.named_parameters(*TEST_CASES) + def testLSTMBasicToBlock(self, dtype, rtol, atol): + self.LSTMBasicToBlockTestHelper( + dtype, use_peephole=False, rtol=rtol, atol=atol) - self.assertAllClose(basic_outputs, block_outputs) - self.assertAllClose(basic_grads, block_grads) - for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, 
block, rtol=1e-6, atol=1e-6) + @parameterized.named_parameters(*TEST_CASES) + def testLSTMBasicToBlockPeeping(self, dtype, rtol, atol): + self.LSTMBasicToBlockTestHelper( + dtype, use_peephole=True, rtol=rtol, atol=atol) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(block_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) + @parameterized.named_parameters(*TEST_CASES) + def testLSTMBasicToBlockCellClip(self, dtype, rtol, atol): + self.LSTMBasicToBlockTestHelper( + dtype, use_peephole=True, cell_clip=0.5, rtol=rtol, atol=atol) def testLSTMFusedSequenceLengths(self): """Verify proper support for sequence lengths in LSTMBlockFusedCell.""" @@ -444,16 +532,21 @@ class BenchmarkLSTMBlock(test.Benchmark): "batch_size": [1, 8, 13, 32, 67, 128], "cell_size": [128, 250, 512, 650, 1024, 1350], "time_steps": [40], - "use_gpu": [True, False] + "use_gpu": [True, False], + "dtype": ["float32", "float16"], }): + dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16 with ops.Graph().as_default(): with benchmarking.device(use_gpu=config["use_gpu"]): inputs = variable_scope.get_variable( "x", - [config["time_steps"], config["batch_size"], config["cell_size"]]) - cell = lstm_ops.LSTMBlockCell(config["cell_size"]) - outputs = rnn.dynamic_rnn( - cell, inputs, time_major=True, dtype=dtypes.float32) + dtype=dtype, + shape=[ + config["time_steps"], config["batch_size"], + config["cell_size"] + ]) + cell = lstm_ops.LSTMBlockCell(config["cell_size"], dtype=dtype) + outputs = rnn.dynamic_rnn(cell, inputs, time_major=True, dtype=dtype) init_op = variables.global_variables_initializer() with session.Session() as sess: @@ -464,12 +557,14 @@ class BenchmarkLSTMBlock(test.Benchmark): # is set, this will produce a copy-paste-able CSV file. 
print(",".join( map(str, [ - config["batch_size"], config["cell_size"], config["cell_size"], - config["time_steps"], config["use_gpu"], wall_time + config["dtype"], config["batch_size"], config["cell_size"], + config["cell_size"], config["time_steps"], config["use_gpu"], + wall_time ]))) benchmark_name_template = "_".join([ - "LSTMBlockCell_fprop", "BS%(batch_size)i", "CS%(cell_size)i", - "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + "LSTMBlockCell_fprop", "DT_%(dtype)s", "BS%(batch_size)i", + "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i", + "gpu_%(use_gpu)s" ]) self.report_benchmark( @@ -488,8 +583,10 @@ class BenchmarkLSTMBlock(test.Benchmark): "batch_size": [1, 8, 13, 32, 67, 128], "cell_size": [128, 250, 512, 650, 1024, 1350], "time_steps": [40], - "use_gpu": [True, False] + "use_gpu": [True, False], + "dtype": ["float32", "float16"], }): + dtype = dtypes.float32 if config["dtype"] == "float32" else dtypes.float16 with ops.Graph().as_default(): with benchmarking.device(use_gpu=config["use_gpu"]): time_steps = config["time_steps"] @@ -498,21 +595,21 @@ class BenchmarkLSTMBlock(test.Benchmark): inputs = variable_scope.get_variable( "x", [time_steps, batch_size, cell_size], trainable=False, - dtype=dtypes.float32) + dtype=dtype) with variable_scope.variable_scope( "rnn", reuse=variable_scope.AUTO_REUSE): w = variable_scope.get_variable( "rnn/lstm_cell/kernel", shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) + dtype=dtype) b = variable_scope.get_variable( "rnn/lstm_cell/bias", shape=[cell_size * 4], - dtype=dtypes.float32, + dtype=dtype, initializer=init_ops.zeros_initializer()) - cell = lstm_ops.LSTMBlockCell(cell_size) + cell = lstm_ops.LSTMBlockCell(cell_size, dtype=dtype) outputs = rnn.dynamic_rnn( - cell, inputs, time_major=True, dtype=dtypes.float32) + cell, inputs, time_major=True, dtype=dtype) grads = gradients_impl.gradients(outputs, [inputs, w, b]) init_op = variables.global_variables_initializer() @@ -524,12 
+621,13 @@ class BenchmarkLSTMBlock(test.Benchmark): # is set, this will produce a copy-paste-able CSV file. print(",".join( map(str, [ - batch_size, cell_size, cell_size, time_steps, config["use_gpu"], - wall_time + config["dtype"], batch_size, cell_size, cell_size, time_steps, + config["use_gpu"], wall_time ]))) benchmark_name_template = "_".join([ - "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i", - "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + "LSTMBlockCell_bprop", "DT_%(dtype)s", "BS%(batch_size)i", + "CS%(cell_size)i", "IS%(cell_size)i", "TS%(time_steps)i", + "gpu_%(use_gpu)s" ]) self.report_benchmark( diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index b30ca7882fce1747cb1dcb27f97f5b012ff9da02..251a933eaec826b08266123245d9aef8573d3e06 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ -21,7 +21,7 @@ from tensorflow.contrib.rnn.ops import gen_gru_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -165,7 +165,7 @@ class GRUBlockCell(LayerRNNCell): num_units = cell_size self._cell_size = num_units # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 2f65c73470110922081166e067443f9e7a6c0596..b043026bc556a8879b15b432829baf8136250c0e 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -25,6 +25,7 @@ from tensorflow.contrib.rnn.ops import gen_lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -156,7 +157,7 @@ def _block_lstm(seq_len_max, Args: seq_len_max: A `Tensor` of type `int64`. - x: A list of at least 1 `Tensor` objects of the same type in: `float32`. + x: A list of at least 1 `Tensor` objects of the same type. w: A `Tensor`. Must have the same type as `x`. b: A `Tensor`. Must have the same type as `x`. cs_prev: A `Tensor`. Must have the same type as `x`. @@ -189,6 +190,7 @@ def _block_lstm(seq_len_max, Raises: ValueError: If `b` does not have a valid shape. 
""" + dtype = x[0].dtype batch_size = x[0].get_shape().with_rank(2).dims[0].value cell_size4 = b.get_shape().with_rank(1).dims[0].value if cell_size4 is None: @@ -197,13 +199,13 @@ def _block_lstm(seq_len_max, zero_state = None if cs_prev is None or h_prev is None: zero_state = array_ops.constant( - 0, dtype=dtypes.float32, shape=[batch_size, cell_size]) + 0, dtype=dtype, shape=[batch_size, cell_size]) if cs_prev is None: cs_prev = zero_state if h_prev is None: h_prev = zero_state if wci is None: - wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) + wci = array_ops.constant(0, dtype=dtype, shape=[cell_size]) wcf = wci wco = wci @@ -384,7 +386,7 @@ class LSTMBlockCell(LayerRNNCell): "scope": "lstm_cell" } # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): @@ -627,7 +629,7 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): self._use_peephole = use_peephole # Inputs must be 3-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=3) + self.input_spec = input_spec.InputSpec(ndim=3) @property def num_units(self): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index e159dc95796e8f02287a4b6db4d25023348fe8da..8a1c09f171e6108174671e3122d5ff4c0b236003 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import activations from tensorflow.python.keras import initializers -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras.engine import input_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_array_ops @@ -2752,7 +2752,7 @@ class SRUCell(rnn_cell_impl.LayerRNNCell): self._activation = activation or math_ops.tanh # Restrict inputs to be 2-dimensional matrices - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) @property def state_size(self): @@ -3089,7 +3089,7 @@ class IndRNNCell(rnn_cell_impl.LayerRNNCell): super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -3183,7 +3183,7 @@ class IndyGRUCell(rnn_cell_impl.LayerRNNCell): super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._activation = activation or math_ops.tanh @@ -3323,7 +3323,7 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self._num_units = num_units self._forget_bias = forget_bias @@ -3444,7 +3444,7 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. - self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self.units = units self.activation = activations.get(activation) @@ -3558,7 +3558,7 @@ class CFNCell(rnn_cell_impl.LayerRNNCell): super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. 
- self.input_spec = base_layer.InputSpec(ndim=2) + self.input_spec = input_spec.InputSpec(ndim=2) self.units = units self.activation = activations.get(activation) diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 291ff83791c7cded2dccc4719bb12e84f00afa42..269443b2c6508bb618d30f64487b1a6a84e8646f 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -82,7 +82,6 @@ py_library( name = "keras_saved_model", srcs = ["python/saved_model/keras_saved_model.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], visibility = ["//visibility:public"], deps = [ "//tensorflow/python:array_ops", @@ -103,7 +102,10 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = [ + "no_oss", # TODO(b/119349471): Re-enable + "no_windows", + ], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py index 6aae4bc5e2981ca4e36e434e577a35a5ac531bba..4c8db94d6f48749d880da284d18aa5a7879b1494 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -19,17 +19,18 @@ from __future__ import division from __future__ import print_function import os +import six from tensorflow.python.client import session from tensorflow.python.estimator import keras as estimator_keras_util from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export as export_helpers -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras import models as models_lib from tensorflow.python.keras import optimizers from 
tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.metrics import Metric from tensorflow.python.keras.models import model_from_json from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables @@ -276,42 +277,40 @@ def _create_signature_def_map(model, mode): inputs_dict.update(targets_dict) outputs_dict = {name: x for name, x in zip(model.output_names, model.outputs)} + metrics = estimator_keras_util._convert_keras_metrics_to_estimator(model) + + # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables + # are by default not added to any collections. We are doing this here, so + # that metric variables get initialized. + local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) + vars_to_add = set() + if metrics is not None: + for key, value in six.iteritems(metrics): + if isinstance(value, Metric): + vars_to_add.update(value.variables) + # Convert Metric instances to (value_tensor, update_op) tuple. + metrics[key] = (value.result(), value.updates[0]) + # Remove variables that are in the local variables collection already. 
+ vars_to_add = vars_to_add.difference(local_vars) + for v in vars_to_add: + ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) + export_outputs = model_fn_lib.export_outputs_for_mode( mode, predictions=outputs_dict, loss=model.total_loss if model.optimizer else None, - metrics=estimator_keras_util._convert_keras_metrics_to_estimator(model)) + metrics=metrics) return export_helpers.build_all_signature_defs( inputs_dict, export_outputs=export_outputs, serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) -def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): +def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument """Assert model and clone contain the same checkpointable objects.""" - def get_non_optimizer_objects(m, g): - """Gather set of model and optimizer checkpointable objects.""" - # Set default graph because optimizer.variables() returns optimizer - # variables defined in the default graph. - with g.as_default(): - all_objects = set(checkpointable_utils.list_objects(m)) - optimizer_and_variables = set() - for obj in all_objects: - if isinstance(obj, optimizers.TFOptimizer): - optimizer_and_variables.update(checkpointable_utils.list_objects(obj)) - optimizer_and_variables.update(set(obj.optimizer.variables())) - return all_objects - optimizer_and_variables - - model_objects = get_non_optimizer_objects(model, model_graph) - clone_objects = get_non_optimizer_objects(clone, clone_graph) - - if len(model_objects) != len(clone_objects): - raise errors.InternalError( - None, None, - 'Model and clone must use the same variables.' - '\n\tModel variables: %s\n\t Clone variables: %s' - % (model_objects, clone_objects)) + # TODO(fchollet, kathywu): make sure this works in eager mode. 
+ return True def load_keras_model(saved_model_path): diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 364b65e06a3cdccc5ec23ddca2403bb28e38598e..d8637effe2ba88689d591482b067ac6f4a1683c1 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -29,7 +29,6 @@ from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import training @@ -150,8 +149,6 @@ class TestModelSavingandLoading(test.TestCase): x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) - model.train_on_batch(x, y) - ref_y = model.predict(x) temp_saved_model = self._save_model_dir() @@ -308,6 +305,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): self, model_builder, uses_learning_phase, optimizer, train_before_export): saved_model_path = self._save_model_dir() with self.session(graph=ops.Graph()): + np.random.seed(130) input_arr = np.random.random((1, 3)) target_arr = np.random.random((1, 3)) @@ -346,16 +344,24 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): inputs, outputs = load_model(sess, output_path, model_fn_lib.ModeKeys.EVAL) - eval_results = sess.run(outputs, {inputs[input_name]: input_arr, - inputs[target_name]: target_arr}) + # First obtain the loss and predictions, and run the metric update op by + # feeding in the inputs and targets. 
+ loss, predictions, _ = sess.run( + (outputs['loss'], outputs['predictions/' + output_name], + outputs['metrics/mean_absolute_error/update_op']), { + inputs[input_name]: input_arr, + inputs[target_name]: target_arr + }) + + # The metric value should be run after the update op, to ensure that it + # reflects the correct value. + metric_value = sess.run(outputs['metrics/mean_absolute_error/value']) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) - self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05) - self.assertAllClose( - ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05) - self.assertAllClose( - ref_predict, eval_results['predictions/' + output_name], atol=1e-05) + self.assertAllClose(ref_loss, loss, atol=1e-05) + self.assertAllClose(ref_mae, metric_value, atol=1e-05) + self.assertAllClose(ref_predict, predictions, atol=1e-05) # Load train graph, and check for the train op, and prediction values with session.Session(graph=ops.Graph()) as sess: @@ -364,8 +370,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) self.assertIn('loss', outputs) - self.assertIn('metrics/mae/update_op', outputs) - self.assertIn('metrics/mae/value', outputs) + self.assertIn('metrics/mean_absolute_error/update_op', outputs) + self.assertIn('metrics/mean_absolute_error/value', outputs) self.assertIn('predictions/' + output_name, outputs) # Train for a step @@ -458,11 +464,6 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001)) clone.train_on_batch(input_arr, target_arr) - with self.assertRaisesRegexp( - errors.InternalError, 'Model and clone must use the same variables.'): - keras_saved_model._assert_same_non_optimizer_objects( - model, model_graph, clone, clone_graph) - def testSaveSeqModelWithoutInputShapesRaisesError(self): 
"""A Sequential model that hasn't been built should raise an error.""" model = sequential_model_without_input_shape(True) diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 6bd58c4d322c04d4d14d04678e24a05c0f876208..5e4f130b31483204a111e2f778fa5d0fc4526fea 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -4,129 +4,11 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "py_test") # @unused - py_library( name = "signal_py", - srcs = ["__init__.py"] + glob(["python/ops/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:spectral_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_library( - name = "test_util", - srcs = ["python/kernel_tests/test_util.py"], + srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:tf_optimizer", - "//tensorflow/python:training", - ], -) - -cuda_py_tests( - name = "mel_ops_test", - srcs = ["python/kernel_tests/mel_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - ], -) - -cuda_py_tests( - name = "mfcc_ops_test", - srcs = ["python/kernel_tests/mfcc_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:spectral_ops_test_util", - ], -) - -cuda_py_tests( - name = "reconstruction_ops_test", - srcs = 
["python/kernel_tests/reconstruction_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "shape_ops_test", - srcs = ["python/kernel_tests/shape_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "spectral_ops_test", - size = "large", - srcs = ["python/kernel_tests/spectral_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops_test_util", - ], - tags = ["nomac"], -) - -cuda_py_tests( - name = "window_ops_test", - srcs = ["python/kernel_tests/window_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", + "//tensorflow/python/ops/signal", ], ) diff --git 
a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index d088e744346aac0aa8675b95d7b792379fc7b019..d01f5ccf51c132082a419ec7db49045ef8bab725 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -14,6 +14,9 @@ # ============================================================================== """Signal processing operations. +`tf.contrib.signal` has been renamed to `tf.signal`. `tf.contrib.signal` will be +removed in TensorFlow 2.0. + See the [Contrib Signal](https://tensorflow.org/api_guides/python/contrib.signal) guide. @@ -39,18 +42,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix -from tensorflow.contrib.signal.python.ops.mfcc_ops import mfccs_from_log_mel_spectrograms -from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add -from tensorflow.contrib.signal.python.ops.shape_ops import frame +from tensorflow.python.ops.signal.mel_ops import linear_to_mel_weight_matrix +from tensorflow.python.ops.signal.mfcc_ops import mfccs_from_log_mel_spectrograms +from tensorflow.python.ops.signal.reconstruction_ops import overlap_and_add +from tensorflow.python.ops.signal.shape_ops import frame +from tensorflow.python.ops.signal.spectral_ops import inverse_stft +from tensorflow.python.ops.signal.spectral_ops import inverse_stft_window_fn +from tensorflow.python.ops.signal.spectral_ops import stft +from tensorflow.python.ops.signal.window_ops import hamming_window +from tensorflow.python.ops.signal.window_ops import hann_window + +from tensorflow.python.util.all_util import remove_undocumented + # `frame` used to be named `frames`, which is a noun and not a verb. # Keep an alias to `frames` for backwards compatibility. 
-from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames -from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft -from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft_window_fn -from tensorflow.contrib.signal.python.ops.spectral_ops import stft -from tensorflow.contrib.signal.python.ops.window_ops import hamming_window -from tensorflow.contrib.signal.python.ops.window_ops import hann_window +frames = frame -from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index 42898e797cc351e3de290cc65fc825f1406c739d..605625c3059868d349da015b8286d219691fc255 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -79,6 +79,7 @@ from tensorflow.python.ops.summary_ops_v2 import image from tensorflow.python.ops.summary_ops_v2 import import_event from tensorflow.python.ops.summary_ops_v2 import initialize from tensorflow.python.ops.summary_ops_v2 import never_record_summaries +from tensorflow.python.ops.summary_ops_v2 import record_summaries from tensorflow.python.ops.summary_ops_v2 import record_summaries_every_n_global_steps from tensorflow.python.ops.summary_ops_v2 import scalar from tensorflow.python.ops.summary_ops_v2 import should_record_summaries diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 4d1807130c57039976dfa57c27bb0d4807e75212..10e4556dacbc17ec02c2bd698389b04d517d7076 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -152,6 +152,27 @@ class EagerFileTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + def testRecordEveryNGlobalSteps(self): + step = training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + + def 
run_step(): + summary_ops.scalar('scalar', i, step=step) + step.assign_add(1) + + with summary_ops.create_file_writer( + logdir).as_default(), summary_ops.record_summaries_every_n_global_steps( + 2, step): + for i in range(10): + run_step() + # And another 10 steps as a graph function. + run_step_fn = function.defun(run_step) + for i in range(10): + run_step_fn() + + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 11) + def testMaxQueue(self): logs = tempfile.mkdtemp() with summary_ops.create_file_writer( @@ -279,12 +300,9 @@ class EagerDbTest(summary_test_util.SummaryDbTest): def testDbURIOpen(self): tmpdb_path = os.path.join(self.get_temp_dir(), 'tmpDbURITest.sqlite') - tmpdb_uri = six.moves.urllib_parse.urljoin("file:", tmpdb_path) - tmpdb_writer = summary_ops.create_db_writer( - tmpdb_uri, - "experimentA", - "run1", - "user1") + tmpdb_uri = six.moves.urllib_parse.urljoin('file:', tmpdb_path) + tmpdb_writer = summary_ops.create_db_writer(tmpdb_uri, 'experimentA', + 'run1', 'user1') with summary_ops.always_record_summaries(): with tmpdb_writer.as_default(): summary_ops.scalar('t1', 2.0) diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 596c59ead3460aa63eeff44d5a11a4a8c5cde0da..290c16fe3966791ea78986539750caf938a37322 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops # pylint: disable=unused-import @@ -28,10 +30,12 @@ from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_mod # pylint: enable=unused-import from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework 
import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking _model_ops = loader.load_op_library( @@ -88,6 +92,59 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject): params=self.params.serialized_params_proto) +class TreeVariable(tracking.TrackableResource): + """A tree model.""" + + def __init__(self, params, tree_config, stats_handle, name, container=None): + self._params = params + self._tree_config = tree_config + self._stats_handle = stats_handle + self._name = name + self._container = container + self._init_op = None + super(TreeVariable, self).__init__() + self._resource_handle = self.create_resource() + + def create_resource(self): + if context.executing_eagerly(): + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). 
+ shared_name = "tree_variable_%d" % (ops.uid(),) + else: + shared_name = self._name + return gen_model_ops.decision_tree_resource_handle_op( + self._container, shared_name=shared_name, name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_variable( + self.resource_handle, + self._tree_config, + params=self._params.serialized_params_proto) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return { + "tree_variable": + functools.partial( + TreeVariableSavable, + params=self._params, + tree_handle=self.resource_handle, + stats_handle=self._stats_handle, + create_op=self._init_op) + } + + def tree_variable(params, tree_config, stats_handle, name, container=None): r"""Creates a tree model and returns a handle to it. @@ -102,18 +159,13 @@ def tree_variable(params, tree_config, stats_handle, name, container=None): A `Tensor` of type mutable `string`. The handle to the tree. """ with ops.name_scope(name, "TreeVariable") as name: - resource_handle = gen_model_ops.decision_tree_resource_handle_op( - container, shared_name=name, name=name) - - create_op = gen_model_ops.create_tree_variable( - resource_handle, - tree_config, - params=params.serialized_params_proto) - is_initialized_op = gen_model_ops.tree_is_initialized_op(resource_handle) + tree_var = TreeVariable(params, tree_config, stats_handle, name, container) + resource_handle = tree_var.resource_handle + create_op = tree_var.initializer + is_initialized_op = tree_var.is_initialized() # Adds the variable to the savable list. 
- saveable = TreeVariableSavable(params, resource_handle, stats_handle, - create_op, - resource_handle.name) + saveable = tree_var._gather_saveables_for_checkpoint()["tree_variable"]( # pylint: disable=protected-access + name=resource_handle.name) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) resources.register_resource(resource_handle, create_op, is_initialized_op) return resource_handle diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 44d486edecc4e4f7ba8a9b6d680178298813621b..9184198cd4c8fd2a7609714d094d5ef2b6868658 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.tensor_forest.python.ops import gen_stats_ops # pylint: disable=unused-import from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import finalize_tree @@ -25,10 +27,12 @@ from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import process_in # pylint: enable=unused-import from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking _stats_ops = loader.load_op_library( @@ -84,8 +88,58 @@ class FertileStatsVariableSavable(saver.BaseSaverBuilder.SaveableObject): params=self.params.serialized_params_proto) -def fertile_stats_variable(params, stats_config, name, - container=None): +class FertileStatsVariable(tracking.TrackableResource): + """A Fertile stats variable.""" + + def __init__(self, params, stats_config, name, container=None): + self._params = params + 
self._stats_config = stats_config + self._name = name + self._container = container + self._init_op = None + super(FertileStatsVariable, self).__init__() + self._resource_handle = self.create_resource() + + def create_resource(self): + if context.executing_eagerly(): + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). + shared_name = "fertile_stats_variable_%d" % (ops.uid(),) + else: + shared_name = self._name + return gen_stats_ops.fertile_stats_resource_handle_op( + self._container, shared_name=shared_name, name=self._name) + + def initialize(self): + return gen_stats_ops.create_fertile_stats_variable( + self.resource_handle, + self._stats_config, + params=self._params.serialized_params_proto) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_stats_ops.fertile_stats_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return { + "fertile_stats_variable": + functools.partial( + FertileStatsVariableSavable, + params=self._params, + stats_handle=self.resource_handle, + create_op=self.initializer) + } + + +def fertile_stats_variable(params, stats_config, name, container=None): r"""Creates a stats object and returns a handle to it. Args: @@ -98,17 +152,15 @@ def fertile_stats_variable(params, stats_config, name, A `Tensor` of type mutable `string`. The handle to the stats. 
""" with ops.name_scope(name, "FertileStatsVariable") as name: - resource_handle = gen_stats_ops.fertile_stats_resource_handle_op( - container, shared_name=name, name=name) - - create_op = gen_stats_ops.create_fertile_stats_variable( - resource_handle, stats_config, - params=params.serialized_params_proto) - is_initialized_op = gen_stats_ops.fertile_stats_is_initialized_op( - resource_handle) + fertile_stats_var = FertileStatsVariable(params, stats_config, name, + container) + resource_handle = fertile_stats_var.resource_handle + create_op = fertile_stats_var.initializer + is_initialized_op = fertile_stats_var.is_initialized() # Adds the variable to the savable list. - saveable = FertileStatsVariableSavable(params, resource_handle, create_op, - resource_handle.name) + saveable = ( + fertile_stats_var._gather_saveables_for_checkpoint()[ # pylint: disable=protected-access + "fertile_stats_variable"](name=resource_handle.name)) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) resources.register_resource(resource_handle, create_op, is_initialized_op) return resource_handle diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc index 3f24f58f03aac2ba6d368d7eccf8731f611a81b4..22b6f09d0cd88068f7bedabe7687920420a3028f 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc @@ -73,7 +73,16 @@ class SummaryFileWriter : public SummaryWriterInterface { e->set_step(global_step); e->set_wall_time(GetWallTime()); Summary::Value* v = e->mutable_summary()->add_value(); - t.AsProtoTensorContent(v->mutable_tensor()); + + if (t.dtype() == DT_STRING) { + // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python + // can convert the TensorProto to string-type numpy array. MakeNdarray + // does not work with strings encoded by AsProtoTensorContent() in + // tensor_content. 
+ t.AsProtoField(v->mutable_tensor()); + } else { + t.AsProtoTensorContent(v->mutable_tensor()); + } v->set_tag(tag); if (!serialized_metadata.empty()) { v->mutable_metadata()->ParseFromString(serialized_metadata); diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc index cd3f712256f2293ed725745f8cbe48109856ef86..ffbfb9533e887e54b0f5bdfde11dadce21073a94 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_file_writer.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/io/path.h" @@ -104,6 +105,23 @@ TEST_F(SummaryFileWriterTest, WriteTensor) { CHECK_EQ(e.summary().value_size(), 1); EXPECT_EQ(e.summary().value(0).tag(), "name"); })); + TF_CHECK_OK(SummaryTestHelper( + "string_tensor_test", + [](SummaryWriterInterface* writer) { + Tensor hello(DT_STRING, TensorShape({})); + hello.scalar()() = "hello"; + TF_RETURN_IF_ERROR(writer->WriteTensor( + 2, hello, "name", SummaryMetadata().SerializeAsString())); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name"); + EXPECT_EQ(e.summary().value(0).tensor().dtype(), DT_STRING); + EXPECT_EQ(e.summary().value(0).tensor().string_val()[0], "hello"); + })); } TEST_F(SummaryFileWriterTest, WriteScalar) { diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index f30c31a78925801da69ebf4e950d70b018cb15d2..d304d72c6aa4092b7f6afdd6859847bd37c93e95 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ 
b/tensorflow/contrib/tensorrt/BUILD @@ -29,6 +29,10 @@ load( "if_tensorrt", ) +exports_files(glob([ + "test/testdata/*", +])) + tf_cuda_cc_test( name = "tensorrt_test_cc", size = "small", @@ -550,6 +554,30 @@ cuda_py_tests( ], ) +cuda_py_test( + name = "quantization_mnist_test", + srcs = ["test/quantization_mnist_test.py"], + additional_deps = [ + ":tf_trt_integration_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/keras:keras", + "//tensorflow/python/estimator:estimator", + ], + data = [ + "test/testdata/checkpoint", + "test/testdata/model.ckpt-46900.data-00000-of-00001", + "test/testdata/model.ckpt-46900.index", + ], + tags = [ + "no_cuda_on_cpu_tap", + "no_pip", + "no_tap", # It is not able to download the mnist data. + "no_windows", + "nomac", + ], +) + cc_library( name = "utils", srcs = ["convert/utils.cc"], diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b1443e7791d6d3c5ed8952d29df12ede77fe8c23..f6d44cb719123ac55ea8c56c34d157e87e244626 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -82,81 +82,79 @@ std::vector GetLoadedTensorRTVersion() { } TrtCandidateSelector::TrtCandidateSelector( - const grappler::GraphProperties& graph_properties, - int precision_mode) + const grappler::GraphProperties& graph_properties, int precision_mode) : graph_properties_(graph_properties), precision_mode_(precision_mode) {} Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. 
// LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", - "AvgPool", - "ConcatV2", - "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", - "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", - "Exp", - "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", - "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", - "Max", - "Min", - "Relu6", + "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + "Div", + "RealDiv", + "Rsqrt", + "Reciprocal", + "Exp", + "Log", + "Sqrt", + "Abs", + "Neg", + "Transpose", + "Reshape", + "MatMul", + "BatchMatMul", + "Softmax", + "Minimum", + "Maximum", + "TopKV2", + "Sum", + "Prod", + "Max", + "Min", + "Relu6", }; - bool is_supported_op_type = (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); -#if NV_TENSORRT_MAJOR >= 5 + bool is_supported_op_type = + (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); static const std::set quantize_ops = { - "QuantizeAndDequantizeV2", - "QuantizeAndDequantizeV3", - "FakeQuantWithMinMaxVars", - "FakeQuantWithMinMaxArgs", + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxArgs", }; // In INT8 mode, we will always apply the quantization ranges provided by // these ops to the relevant tensors. This happens regardless of the value of // use_calibration. 
- if (precision_mode_ == INT8MODE && - quantize_ops.count(node->type_string())) { + if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) { is_supported_op_type = true; } -#endif // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) if (!is_supported_op_type) { return errors::Unimplemented("Op type ", node->type_string(), - " is not supported."); + " is not supported"); } std::vector input_edges; TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); std::vector> input_node_and_ports; + input_node_and_ports.reserve(input_edges.size()); for (const Edge* input_edge : input_edges) { input_node_and_ports.emplace_back(&input_edge->src()->def(), input_edge->src_output()); @@ -280,7 +278,9 @@ tensorflow::Status ConvertGraphDefToTensorRT( #endif // Create RewriterConfig. - tensorflow::RewriterConfig rw_cfg; + tensorflow::ConfigProto config_proto; + auto& rw_cfg = + *config_proto.mutable_graph_options()->mutable_rewrite_options(); // TODO(aaroey): use only const folding and layout for the time being since // new optimizers break the graph for trt. rw_cfg.add_optimizers("constfold"); @@ -304,7 +304,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( parameters["use_calibration"].set_b(use_calibration); // Run optimizer. - tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); + tensorflow::grappler::MetaOptimizer meta_opt(nullptr, config_proto); TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); if (VLOG_IS_ON(5)) { @@ -582,11 +582,11 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, } } - const bool calibrate_int8 = (info.precision_mode == INT8MODE && info.use_calibration); + const bool calibrate_int8 = + (info.precision_mode == INT8MODE && info.use_calibration); // Build the engine and get its serialized representation. 
string segment_string; - if (info.engine_type == EngineInfo::EngineType::TRTStatic || - calibrate_int8) { + if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) { // Create static engine for fp32/fp16 mode, and test validity of the engine // for int8 calibration mode. We don't want engine to fail at the // calibration time. So we are constructing a FP32 engine here to check its @@ -596,8 +596,7 @@ tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, TrtUniquePtrType engine; // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( - info.segment_graph_def, - calibrate_int8 ? FP32MODE : info.precision_mode, + info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, &trt_logger, alloc, /*calibrator=*/nullptr, &engine, info.use_calibration, @@ -927,12 +926,12 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { } curr_engine.precision_mode = params.precision_mode; if (params.use_calibration && params.precision_mode != INT8MODE) { - return tensorflow::errors::Unimplemented( - "Calibration with FP32 or FP16 is not implemented. "); + return errors::InvalidArgument( + "Calibration with FP32 or FP16 is not supported."); } curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration) - ? EngineInfo::EngineType::TRTDynamic - : EngineInfo::EngineType::TRTStatic); + ? 
EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); curr_engine.use_calibration = params.use_calibration; curr_engine.cached_engine_batches = params.cached_engine_batches; curr_engine.maximum_cached_engines = params.max_cached_engines; @@ -952,7 +951,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { converted_segments.push_back(std::move(curr_segment)); if (VLOG_IS_ON(8)) { - string fname = curr_engine.engine_name; + string fname = engine_segments.back().engine_name; StrAppend(&fname, ".pb"); std::fstream f; f.open(fname.c_str(), std::fstream::out | std::fstream::binary); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 366d69115b2ee000d77e088539a06ac8c88134ee..2904e73abc01b11400f73cd5779c1aeceb0d7e0b 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -53,6 +53,9 @@ class TrtCandidateSelector { // GraphProperties of the graph whose nodes are to be validated by // IsTensorRTCandidate(). const grappler::GraphProperties& graph_properties_; + + // Quantization ops are only converted when using quantized precisions. 
+ const int precision_mode_; }; struct ConversionParams { @@ -101,8 +104,7 @@ tensorflow::Status ConvertGraphDefToTensorRT( size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode = 1, int minimum_segment_size = 3, bool is_dyn_op = false, int max_cached_engines = 1, - std::vector cached_engine_batches = {}, - bool use_calibration = true); + std::vector cached_engine_batches = {}, bool use_calibration = true); // Method to call from optimization pass tensorflow::Status ConvertAfterShapes(ConversionParams& params); diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc index e2a7c40f301891b99c7cadb9f526233c0b81f461..2d2bfeb192c1893824c7b30bfad593c62c203392 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -85,27 +85,42 @@ TEST(TrtCandidateSelector, Basics) { ops::MatMul(s.WithOpName("matmul_with_incompatible_input"), incompatible_feed, const_2); + // Quantize ops. + auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); + auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed, + quantize_attrs); + + // Get GrapplerItem and GraphProperties. 
grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); Tensor feed_tensor(DT_FLOAT, input_shape); item.feed.push_back(std::make_pair("feed", feed_tensor)); - grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - TrtCandidateSelector selector(graph_properties, FP32MODE); - TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); - ExpectStatus( - selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), - error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected " - "(op: MatMul), at: incompatible_matmul"); - ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), - error::UNIMPLEMENTED, "Op type Sin is not supported"); - ExpectStatus(selector.IsTensorRTCandidate( - matmul_with_incompatible_input.operation.node()), - error::INTERNAL, - "Failed to convert input with index 0 to a TRT_TensorOrWeights"); + for (const int precision_mode : {FP32MODE, INT8MODE}) { + TrtCandidateSelector selector(graph_properties, precision_mode); + TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node())); + ExpectStatus( + selector.IsTensorRTCandidate(incompatible_matmul.operation.node()), + error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected " + "(op: MatMul), at: incompatible_matmul"); + ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()), + error::UNIMPLEMENTED, "Op type Sin is not supported"); + ExpectStatus( + selector.IsTensorRTCandidate( + matmul_with_incompatible_input.operation.node()), + error::INTERNAL, + "Failed to convert input with index 0 to a TRT_TensorOrWeights"); + if (precision_mode == INT8MODE) { + TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node())); + } else { + ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()), + error::UNIMPLEMENTED, + "Op type FakeQuantWithMinMaxArgs is not supported"); + } + } } class FakeCluster : public 
grappler::Cluster { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index bc29bf8ffacf53492da2591fdc56513ac8f8c694..cb2a1ca87ac7434e7480ee09f14071b67f107410 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -54,10 +54,10 @@ limitations under the License. // would work! #define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) -#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ - do { \ - return tensorflow::errors::Internal( \ - "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \ +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return tensorflow::errors::Internal( \ + "TFTRT::", __FUNCTION__, " failed to add TRT layer, at: ", node); \ } while (0) #define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ @@ -187,10 +187,49 @@ Status ValidateTensorProperties(const string& producer_node_type, return Status::OK(); } +string DebugString(const nvinfer1::DimensionType type) { + switch (type) { + case nvinfer1::DimensionType::kSPATIAL: + return "kSPATIAL"; + case nvinfer1::DimensionType::kCHANNEL: + return "kCHANNEL"; + case nvinfer1::DimensionType::kINDEX: + return "kINDEX"; + case nvinfer1::DimensionType::kSEQUENCE: + return "kSEQUENCE"; + default: + return StrCat(static_cast(type), "=unknown"); + } +} + +string DebugString(const nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return "kFLOAT"; + case nvinfer1::DataType::kHALF: + return "kHALF"; + case nvinfer1::DataType::kINT8: + return "kINT8"; + case nvinfer1::DataType::kINT32: + return "kINT32"; + default: + return "Invalid TRT data type"; + } +} + string DebugString(const nvinfer1::Dims& dims) { string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); for (int i = 0; i < dims.nbDims; ++i) { - StrAppend(&out, dims.d[i], ","); + StrAppend(&out, dims.d[i], "[", DebugString(dims.type[i]), "],"); 
+ } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::Permutation& permutation, int len) { + string out = "nvinfer1::Permutation("; + for (int i = 0; i < len; ++i) { + StrAppend(&out, permutation.order[i], ","); } StrAppend(&out, ")"); return out; @@ -198,16 +237,15 @@ string DebugString(const nvinfer1::Dims& dims) { string DebugString(const nvinfer1::ITensor& tensor) { return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), - ", shape=", DebugString(tensor.getDimensions()), ")"); + ", name=", tensor.getName(), + ", dtype=", DebugString(tensor.getType()), + ", dims=", DebugString(tensor.getDimensions()), ")"); } -// Return whether or not the broadcast is feasible; -bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, - const bool operand_l_is_tensor, - const nvinfer1::Dims& operand_r, - const bool operand_r_is_tensor, - nvinfer1::Dims* operand_l_new_shape, - nvinfer1::Dims* operand_r_new_shape) { +Status Converter::GetTrtBroadcastShape( + const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) const { // *************************************************************************** // TensorRT Elementwise op supports broadcast but requires both tensor to be // of Identical rank @@ -232,52 +270,59 @@ bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, // -> T: 1 1 1 -1 3 5 1 // -> W: 1 1 1 1 3 5 1 // *************************************************************************** - const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; - const size_t element_size = sizeof(operand_l.d[0]); - - // fill in dimensions - int l_s[max_nb_dims]; - std::fill(l_s, l_s + max_nb_dims, 1); - int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims; - int r_s[max_nb_dims]; - std::fill(r_s, r_s + max_nb_dims, 1); - int r_d = operand_r_is_tensor ? 
operand_r.nbDims + 1 : operand_r.nbDims; - - int max_d = std::max(l_d, r_d); - std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d, - operand_l.nbDims * element_size); - std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d, - operand_r.nbDims * element_size); - - // set -1 for batch dimension, since batch size is not supposed to be - // broadcasted - if (operand_l_is_tensor) { - if (max_d != l_d) { // if broadcast beyond batch dimension, fail - return false; - } - l_s[0] = -1; - } - if (operand_r_is_tensor) { - if (max_d != r_d) { // if broadcast beyond batch dimension, fail - return false; - } - r_s[0] = -1; + if (!operand_l.is_tensor() && !operand_r.is_tensor()) { + return errors::InvalidArgument( + "Broadcasting requires at least one of the operands be tensors"); } - // compare broadcast feasibility - for (int i = max_d - 1; i >= 0; i--) { - if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) { - return false; + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; + auto compute_output_dims = + [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims, + int* output_dims_array, nvinfer1::Dims* output_dims) { + const nvinfer1::Dims input_dims = input.GetTrtDims(); + std::fill(output_dims_array, output_dims_array + max_nb_dims, 1); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, + output_dims_array + broadcast_num_dims - input_dims.nbDims); + if (input.is_tensor()) { + const int true_input_dims = input_dims.nbDims + 1; + if (true_input_dims < broadcast_num_dims) { + return errors::InvalidArgument( + "Broadcasting beyond batch dimension is not supported ", + "(tensor #dims ", true_input_dims, " vs broadcast #dims ", + broadcast_num_dims, ")"); + } + // Set the batch dimension to -1, since batch size is not supposed to + // be broadcasted. + output_dims_array[0] = -1; + } + // Copy to output dimensions (stripping the batch dimension). 
+ output_dims->nbDims = broadcast_num_dims - 1; + std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims, + output_dims->d); + return Status::OK(); + }; + + // Compute the output dimensions. + const int broadcast_num_dims = + std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0), + operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0)); + int output_l[max_nb_dims], output_r[max_nb_dims]; + TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims, + output_l, operand_l_new_dims)); + TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims, + output_r, operand_r_new_dims)); + + // Compare broadcast feasibility + for (int i = 0; i < broadcast_num_dims; ++i) { + if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && + (output_r[i] != 1)) { + return errors::InvalidArgument( + "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ", + DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0], + ", ", DebugString(*operand_r_new_dims), ")"); } } - - // output new TensorRT Dimension (stripping the batch dimension) - operand_l_new_shape->nbDims = max_d - 1; - std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size); - operand_r_new_shape->nbDims = max_d - 1; - std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size); - - return true; + return Status::OK(); } inline bool DimsEqual(const nvinfer1::Dims& dim_l, @@ -381,7 +426,7 @@ size_t TRT_ShapedWeights::size_bytes() const { string TRT_ShapedWeights::DebugString() const { return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", type_, + ", type=", DataTypeString(type_), ", values=", reinterpret_cast(GetValues()), ")"); } @@ -491,8 +536,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor @", reinterpret_cast(tensor()), - ", 
shape=", convert::DebugString(tensor()->getDimensions()), + StrAppend(&output, "tensor=", convert::DebugString(*tensor()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); @@ -755,8 +799,9 @@ Status TrtNodeValidator::ValidateNode( Status status = ConvertToTensorOrWeights( *pair.first, pair.second, graph_properties, &tensor_or_weights); if (!status.ok()) { - return errors::Internal("Failed to convert input with index ", i, - " to a TRT_TensorOrWeights"); + return errors::Internal( + "Failed to convert input with index ", i, + " to a TRT_TensorOrWeights: ", status.error_message()); } inputs.push_back(tensor_or_weights); } @@ -789,8 +834,7 @@ Status TrtNodeValidator::ConvertConstToWeights( } Converter::Converter(nvinfer1::INetworkDefinition* trt_network, - int precision_mode, - bool use_calibration) + int precision_mode, bool use_calibration) : trt_network_(trt_network), precision_mode_(precision_mode), use_calibration_(use_calibration) { @@ -947,6 +991,8 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, for (int32_t i = 0; i < dims.nbDims; ++i) { permutation.order[i] = order_with_batch_dim[i + 1] - 1; } + VLOG(1) << "TransposeTensor permutation: " + << DebugString(permutation, dims.nbDims); layer->setFirstTranspose(permutation); nvinfer1::Dims reshape_dims; @@ -963,24 +1009,23 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, } Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, - float* out_min, - float* out_max) const { + float* out_min, float* out_max) const { switch (weights.type_) { - case tensorflow::DataType::DT_FLOAT: { + case DataType::DT_FLOAT: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = *result.first; *out_max = *result.second; break; } - case tensorflow::DataType::DT_HALF: { + case DataType::DT_HALF: { auto inp = static_cast(weights.GetValues()); auto result = 
std::minmax_element(inp, inp + weights.count()); *out_min = Eigen::half_impl::half_to_float(*result.first); *out_max = Eigen::half_impl::half_to_float(*result.second); break; } - case tensorflow::DataType::DT_INT32: { + case DataType::DT_INT32: { auto inp = static_cast(weights.GetValues()); auto result = std::minmax_element(inp, inp + weights.count()); *out_min = static_cast(*result.first); @@ -988,11 +1033,11 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, break; } default: - return tensorflow::errors::Unimplemented( - "Data type not supported for GetWeightRange: " + - tensorflow::DataTypeString(weights.type_)); + return errors::Unimplemented( + "Data type not supported for GetWeightRange: ", + DataTypeString(weights.type_)); } - return tensorflow::Status::OK(); + return Status::OK(); } Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, @@ -1009,8 +1054,9 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, } if (can_check_shapes && TrtDimsNumElements(input.GetTrtDims()) != TrtDimsNumElements(dims)) { - return tensorflow::errors::InvalidArgument( - "Reshape shapes are not compatible."); + return errors::InvalidArgument("Reshape shapes are not compatible (", + DebugString(input.GetTrtDims()), " vs ", + DebugString(dims), ")"); } if (input.is_tensor()) { @@ -1038,15 +1084,15 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, float max_range = 0.0f; TF_RETURN_IF_ERROR( GetWeightRange(input.weights(), &min_range, &max_range)); - // Avoid setting range to 0 because TRT will throw an error. If the weights - // are zero then the range doesn't matter: using 127.0f should ensure the - // quantized weight will be exactly zero. + // Avoid setting range to 0 because TRT will throw an error. If the + // weights are zero then the range doesn't matter: using 127.0f should + // ensure the quantized weight will be exactly zero. 
if (min_range == 0.0f && max_range == 0.0f) { min_range = -127.0f; max_range = 127.0f; } ProvideQuantizationRange(const_cast(*tensor), - min_range, max_range); + min_range, max_range); } } return tensorflow::Status::OK(); @@ -1064,25 +1110,29 @@ void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor, quantization_ranges_[tensor] = symmetric_range; } -void Converter::ApplyQuantizationRanges(bool warn_missing_ranges) { - // Infer ranges across marked ops +void Converter::MaybeApplyQuantizationRanges() { + if (precision_mode() != INT8MODE) return; + + // Infer ranges across marked ops. PropagateQuantizationRanges(); - // Apply ranges + // Apply ranges. +#if NV_TENSORRT_MAJOR >= 5 for (auto pair : quantization_ranges_) { nvinfer1::ITensor* tensor = pair.first; const float range = pair.second; -#if NV_TENSORRT_MAJOR >= 5 VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range; + // TODO(laigd): if 'tensor' already has a range set which doesn't match + // 'range', it should report error. tensor->setDynamicRange(-range, range); -#endif } +#endif // Warn user about tensors that are missing ranges. If TRT fuses some layers // then these tensors may not actually be required, which is why this is // just a warning. If we are still missing ranges even after fusion, // Builder::buildCudaEngine() will return nullptr and we will catch the // error at that point. - if (warn_missing_ranges) { + if (!use_calibration()) { // Get all tensors from network std::set all_tensors; for (int i = 0; i < this->network()->getNbLayers(); i++) { @@ -1096,15 +1146,15 @@ void Converter::ApplyQuantizationRanges(bool warn_missing_ranges) { } // Find tensors with no ranges for (auto tensor : all_tensors) { - if (quantization_ranges_.find(tensor) == quantization_ranges_.end()) { + if (!quantization_ranges_.count(tensor)) { // Note: there may be some warnings for "(Unnamed ITensor* N)". These // are tensors which are created internally by TF-TRT. 
The ranges for // these unnamed ITensors are always inferred from user provided ranges, // thus there will also be a warning for the range(s) the user missed. LOG(WARNING) << "Quantization range was not found for " - << tensor->getName() << ". " - << "This is okay if TensorRT does not need the range " - << "(e.g. due to node fusion)."; + << tensor->getName() << ". " + << "This is okay if TensorRT does not need the range " + << "(e.g. due to node fusion)."; } } } @@ -1120,20 +1170,21 @@ void Converter::PropagateQuantizationRanges() { while (information_added) { information_added = false; for (auto it = quantization_infer_.begin(); - it != quantization_infer_.end();) { + it != quantization_infer_.end();) { auto input_tensor_range = quantization_ranges_.find(it->first); auto output_tensor_range = quantization_ranges_.find(it->second); if (input_tensor_range != quantization_ranges_.end() && output_tensor_range == quantization_ranges_.end()) { // Input has range but output doesn't: copy range + // TODO(laigd): consider reporting error if it a different range is + // already set. quantization_ranges_[it->second] = input_tensor_range->second; information_added = true; - VLOG(1) << "Copy quantization range: " - << it->first->getName() << " -> " << it->second->getName(); + VLOG(1) << "Copy quantization range: " << it->first->getName() << " -> " + << it->second->getName(); } // We can remove edges when the output range is known - if (quantization_ranges_.find(it->second) != - quantization_ranges_.end()) { + if (quantization_ranges_.find(it->second) != quantization_ranges_.end()) { it = quantization_infer_.erase(it); } else { ++it; @@ -1198,12 +1249,11 @@ TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store, } // **************************************************************************** -// Constant folding functions -// TODO(jie): once optimizer kicks in, we should have done constant folding -// there. +// Constant folding functions for weights. 
+// TODO(laigd): we should probably use eigen directly. // ***************************************************************************** struct LambdaFactory { - enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP }; + enum class OP_CATEGORY : int { RSQRT = 0, NEG, RECIP }; OP_CATEGORY op; template @@ -1218,84 +1268,10 @@ struct LambdaFactory { case OP_CATEGORY::RECIP: return [](T t) -> T { return 1.0 / t; }; default: - VLOG(2) << "Not supported op for unary: " << static_cast(op); + LOG(ERROR) << "Not supported op for unary: " << static_cast(op); return nullptr; } } - - template - std::function binary() { - switch (op) { - case OP_CATEGORY::ADD: - return [](T l, T r) -> T { return l + r; }; - case OP_CATEGORY::SUB: - return [](T l, T r) -> T { return l - r; }; - case OP_CATEGORY::MUL: - return [](T l, T r) -> T { return l * r; }; - default: - LOG(WARNING) << "Not supported op for binary: " << static_cast(op); - } - return [](T l, T r) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } - - template - std::function broadcast_r(T val) { - VLOG(2) << "LAMBDA VAL : " << val; - switch (op) { - case OP_CATEGORY::ADD: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l + val; - }; - case OP_CATEGORY::SUB: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l - val; - }; - case OP_CATEGORY::MUL: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return l * val; - }; - default: - LOG(WARNING) << "Not supported op for binary: " << static_cast(op); - } - return [val](T l) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } - - template - std::function broadcast_l(T val) { - VLOG(2) << "LAMBDA VAL : " << val; - switch (op) { - case OP_CATEGORY::ADD: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val + l; - }; - case OP_CATEGORY::SUB: - return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val - l; - }; - case OP_CATEGORY::MUL: - 
return [val](T l) -> T { - VLOG(2) << "LAMBDA VAL : " << val; - return val * l; - }; - default: - LOG(ERROR) << "Not supported op for binary: " << static_cast(op); - } - return [val](T l) -> T { - LOG(FATAL) << "Unsupported op type "; - return l; - }; - } }; template <> @@ -1303,15 +1279,18 @@ std::function LambdaFactory::unary() { switch (op) { case OP_CATEGORY::RSQRT: { VLOG(2) << "RSQRT GETS DONE"; - return [](Eigen::half t) -> Eigen::half { + return [](Eigen::half t) { return Eigen::half(1.0 / sqrt(static_cast(t))); }; } case OP_CATEGORY::NEG: - return [](Eigen::half t) -> Eigen::half { return -t; }; - // TODO(aaroey): can we support RECIP? + return [](Eigen::half t) { return -t; }; + case OP_CATEGORY::RECIP: + return [](Eigen::half t) { + return Eigen::half(1.0 / static_cast(t)); + }; default: - VLOG(2) << "Not supported op for unary: " << static_cast(op); + LOG(ERROR) << "Not supported op for unary: " << static_cast(op); return nullptr; } } @@ -1343,50 +1322,48 @@ tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, return tensorflow::Status::OK(); } +// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the +// right operand. If swapped_inputs is true, those two are swapped. +// // TODO(jie): broadcast is needed yet not implemented. -// Only implemented channel wise for the time being -tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, - const nvinfer1::ITensor* tensor, - TRT_ShapedWeights weights, - bool swapped_inputs) { +// Only implemented channel wise for the time being. +Status BinaryTensorOpWeight(OpConverterParams* params, + const nvinfer1::ITensor* tensor, + TRT_ShapedWeights weights, bool swapped_inputs) { + static const std::unordered_set supported_ops = {"Sub", "Add", "Mul", + "Div", "RealDiv"}; const auto& node_def = params->node_def; - // tensor is the left operand while weights is the right operand; - // when swapped_inputs set to true, those two are swapped. - // TODO(aaroey): use a set. 
- if (node_def.op() != "Sub" && node_def.op() != "Add" && - node_def.op() != "Mul" && node_def.op() != "Div" && - node_def.op() != "RealDiv") { - return tensorflow::errors::Unimplemented( - "op not supported: " + node_def.op() + ", at: " + node_def.name()); + if (!supported_ops.count(node_def.op())) { + return errors::Unimplemented(node_def.op(), " is not supported, at ", + node_def.name()); } - // Check type consistency - nvinfer1::DataType ttype; - TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &ttype)); + // Check type consistency. + nvinfer1::DataType trt_dtype; + TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype)); - // Check scale mode + // Check scale mode. auto dims_w = weights.shape_; - auto dims_t = tensor->getDimensions(); + const auto dims_t = tensor->getDimensions(); // TODO(jie): addScale checks for input tensor dimension if (dims_t.nbDims != 3) { - return tensorflow::errors::InvalidArgument( - "addScale requires tensor with rank 3, " + node_def.name()); + return errors::InvalidArgument("addScale requires tensor with rank 3, at ", + node_def.name()); } - // default to element-wise + // Default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; // TODO(jie): maybe use a permutation instead to support more cases; - bool permutation_flag = false; + bool need_to_permute = false; if (weights.count() == 1) { - VLOG(2) << "UNIFORM"; scale_mode = nvinfer1::ScaleMode::kUNIFORM; } else { - // no broadcasting on Batch dimension; - VLOG(2) << "WEIGHTS DIM: " << dims_w.nbDims - << " tensor DIM: " << dims_t.nbDims; + VLOG(2) << "weights dims: " << DebugString(dims_w) + << "; tensor dims: " << DebugString(dims_t); + // Make sure no broadcasting on batch dimension. 
if (dims_w.nbDims == dims_t.nbDims + 1) { if (dims_w.d[0] == 1) { for (int i = 1; i < dims_w.nbDims; i++) { @@ -1394,72 +1371,70 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } dims_w.nbDims--; } else { - return tensorflow::errors::InvalidArgument( - "Binary op cannot operate on batch, " + node_def.name()); + return errors::InvalidArgument("Binary op cannot operate on batch, at ", + node_def.name()); } } if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) { scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - // default is element; + // Default is element-wise for (int i = 1; i < dims_w.nbDims; i++) { if (dims_w.d[i] != dims_t.d[i]) { - // if dimension does not match, switch back to channel; - VLOG(2) << "channel"; + // If dimension does not match, switch back to per-channel scale_mode = nvinfer1::ScaleMode::kCHANNEL; break; } } - // if channel as candidate, validate it + // If the mode is per-channel, since channel dimension is assumed to be + // the third to last dimension, we need to make sure all other dimensions + // have size 1. if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { for (int i = 1; i < dims_w.nbDims; i++) { if (dims_w.d[i] != 1) - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + return errors::InvalidArgument( + "Weight dims not compatible for channel-wise broadcast at ", + node_def.name()); } - } else { - VLOG(2) << "elementwise"; } } else if (dims_w.nbDims == 1 && dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) { - // channel wise and broadcast required; - permutation_flag = true; + // Channel wise and broadcast required. We compare the last dimension of + // the tensor shape because of tensorflow default broadcasting rules. 
+ need_to_permute = true; scale_mode = nvinfer1::ScaleMode::kCHANNEL; } else { - return tensorflow::errors::InvalidArgument( - "Weight shape not compatible at, " + node_def.name()); + return errors::InvalidArgument("Weight dims not compatible at ", + node_def.name()); } } + // TODO(laigd): we should add validation_only support in TransposeTensor() and + // PrepareTensorForShape(). + if (params->validation_only) return Status::OK(); - // transpose last dimension + // Transpose last dimension. std::vector permutation(dims_t.nbDims + 1); - if (permutation_flag) { - if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) { - // we swap the last dimension into channel for trt. - // because of tensorflow default broadcasting rules. - for (int i = 0; i < static_cast(permutation.size()); i++) { - permutation[i] = i; - } - permutation[1] = dims_t.nbDims; - permutation[dims_t.nbDims] = 1; - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - const_cast(tensor), permutation, &tensor)); - } else { - return tensorflow::errors::InvalidArgument( - "Transpose cannot be applied, " + node_def.name()); - } + if (need_to_permute) { + // We swap the last dimension into channel for trt, because of tensorflow + // default broadcasting rules. 
+ for (int i = 0; i < static_cast(permutation.size()); i++) { + permutation[i] = i; + } + permutation[1] = dims_t.nbDims; + permutation[dims_t.nbDims] = 1; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + const_cast(tensor), permutation, &tensor)); } if (params->converter->precision_mode() == FP16MODE) { weights = ConvertFP32ToFP16(params->weight_store, weights); } - // prepare weights + // Prepare weights TRT_ShapedWeights shift_weights(weights.type_); TRT_ShapedWeights scale_weights(weights.type_); TRT_ShapedWeights power_weights(weights.type_); - // Maybe I should do a switch if (node_def.op() == "Sub") { if (swapped_inputs) { shift_weights = weights; @@ -1482,19 +1457,21 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { if (swapped_inputs) { - // We need to infer the quantization range for this intermediate - // tensor. - // x -> [Recip] -> 1/x -> [Scale] -> s/x - // ^ - // need range for this + // We need to infer the quantization range for this intermediate tensor. + // + // x -> [Recip] -> 1/x -> [Scale] -> s/x + // ^ + // need range for this + // // We have the quantization scales for x and s/x - can we divide the scale - // for s/x by s? Only if it was a scalar... + // for s/x by s? Only if it is a scalar. + // // Because of this issue, fall back to BinaryTensorOpTensor if we are // doing INT8 with no calibration. There is most likely no performance // penalty by falling back here. if (params->converter->precision_mode() == INT8MODE && !params->converter->use_calibration()) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Intermediate quantization range cannot be determined without" " calibration. 
Falling back to BinaryTensorOpTensor for ", node_def.op(), ", at ", node_def.name()); @@ -1518,8 +1495,8 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, } else if (node_def.op() == "Add") { shift_weights = weights; } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); + // This should not happen. + return errors::Unimplemented("Binary op not supported at ", node_def.op()); } nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( @@ -1529,8 +1506,8 @@ tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params, TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); const nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // transpose back dimension - if (permutation_flag) { + // Transpose back dimension + if (need_to_permute) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( const_cast(output_tensor), permutation, &output_tensor)); @@ -1664,9 +1641,9 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, params->node_def.name()); } -tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params, - const TRT_TensorOrWeights& operand_l, - const TRT_TensorOrWeights& operand_r) { +Status BinaryTensorOpTensor(OpConverterParams* params, + const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r) { const auto& node_def = params->node_def; static const std::unordered_map ops{ {"Add", nvinfer1::ElementWiseOperation::kSUM}, @@ -1677,50 +1654,52 @@ tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params, {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, }; + auto op_pair = ops.find(node_def.op()); + if (op_pair == ops.end()) { + return errors::Unimplemented("Binary op ", node_def.op(), + " not supported at: ", node_def.name()); + } - const nvinfer1::ITensor* tensor_l; - const nvinfer1::ITensor* tensor_r; - - nvinfer1::Dims dim_l; - nvinfer1::Dims dim_r; - - if 
(!TensorRTGetBroadcastShape(operand_l.GetTrtDims(), operand_l.is_tensor(), - operand_r.GetTrtDims(), operand_r.is_tensor(), - &dim_l, &dim_r)) { - return tensorflow::errors::InvalidArgument( - "Binary op broadcast scheme not supported by TensorRT op: " + - node_def.op() + ", at: " + node_def.name()); + nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r; + Status status = params->converter->GetTrtBroadcastShape( + operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r); + if (!status.ok()) { + return errors::InvalidArgument( + "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ", + status.error_message()); } + if (params->validation_only) return Status::OK(); - TF_RETURN_IF_ERROR( - params->converter->PrepareTensorForShape(operand_l, dim_l, &tensor_l)); - TF_RETURN_IF_ERROR( - params->converter->PrepareTensorForShape(operand_r, dim_r, &tensor_r)); + const nvinfer1::ITensor* tensor_l = nullptr; + const nvinfer1::ITensor* tensor_r = nullptr; + status = params->converter->PrepareTensorForShape( + operand_l, broadcasted_dims_l, &tensor_l); + if (status.ok()) { + status = params->converter->PrepareTensorForShape( + operand_r, broadcasted_dims_r, &tensor_r); + } + if (!status.ok()) { + return errors::Internal("Failed to convert binary op ", node_def.name(), + ": ", status.error_message()); + } - // get trt type & shape + // Check type consistency. 
TFAttrs attrs(node_def); - // maybe this part has to be moved into the block of rsqrt later nvinfer1::DataType dtype = attrs.get("T"); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype) + << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype) + << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype); - // check type consistency - TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype); - TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype); - auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) { - return tensorflow::errors::Unimplemented( - "binary op: ", node_def.op(), " not supported at: ", node_def.name()); - } - + // Add ElementWise layer. nvinfer1::IElementWiseLayer* layer = params->converter->network()->addElementWise( - // TODO(aaroey): will tensor_l/tensor_r get modified? *const_cast(tensor_l), *const_cast(tensor_r), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // pass the output + // Pass the output params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1730,6 +1709,7 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { const auto& node_def = params->node_def; // prepare input std::vector all_inputs; + all_inputs.reserve(inputs.size()); for (auto input : inputs) { all_inputs.emplace_back(const_cast(input.tensor())); } @@ -2008,23 +1988,22 @@ tensorflow::Status ConvertActivation(OpConverterParams* params) { return tensorflow::Status::OK(); } -tensorflow::Status ConvertQuantize(OpConverterParams* params) { +Status ConvertQuantize(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if ((inputs.size() == 0) || - (inputs.size() != 1 && node_def.op() == "FakeQuantWithMinMaxArgs") || - (inputs.size() != 3 && node_def.op() == "FakeQuantWithMinMaxVars") || - (inputs.size() != 3 && 
node_def.op() == "QuantizeAndDequantizeV2") || - (inputs.size() != 4 && node_def.op() == "QuantizeAndDequantizeV3")) { - return tensorflow::errors::InvalidArgument( - "Invalid number of inputs for ", node_def.op(), ", at ", - node_def.name()); + (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) || + (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) || + (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) || + (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) { + return errors::InvalidArgument("Invalid number of inputs for ", + node_def.op(), ", at ", node_def.name()); } if (inputs.at(0).is_weights()) { // TensorRT will automatically quantize weights, so we will ignore ranges // for weights. params->outputs->push_back(inputs.at(0)); - return tensorflow::Status::OK(); + return Status::OK(); } float min_range = 0.0f; float max_range = 0.0f; @@ -2032,9 +2011,8 @@ tensorflow::Status ConvertQuantize(OpConverterParams* params) { // Get ranges via node attributes. TFAttrs attrs(node_def); if (attrs.count("min") == 0 || attrs.count("max") == 0) { - return tensorflow::errors::InvalidArgument( - "Min or max attribute not found for ", node_def.op(), " at ", - node_def.name()); + return errors::InvalidArgument("Min or max attribute not found for ", + node_def.op(), " at ", node_def.name()); } min_range = attrs.get("min"); max_range = attrs.get("max"); @@ -2043,29 +2021,26 @@ tensorflow::Status ConvertQuantize(OpConverterParams* params) { node_def.op() == "QuantizeAndDequantizeV3") { // Get ranges via inputs. 
if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) { - return tensorflow::errors::InvalidArgument( - "Min and max inputs for ", node_def.op(), - " must be weights not tensors, at ", node_def.name()); - } - // Min - TRT_ShapedWeights weights_min = inputs.at(1).weights(); - auto weights_min_ptr = static_cast(const_cast( - weights_min.GetValues())); - min_range = weights_min_ptr[0]; - // Max - TRT_ShapedWeights weights_max = inputs.at(2).weights(); - auto weights_max_ptr = static_cast(const_cast( - weights_max.GetValues())); - max_range = weights_max_ptr[0]; + return errors::InvalidArgument("Min and max inputs for ", node_def.op(), + " must be weights not tensors, at ", + node_def.name()); + } + auto get_weights_value = [&inputs](int index) { + auto raw_weights = static_cast( + const_cast(inputs.at(index).weights().GetValues())); + return raw_weights[0]; + }; + min_range = get_weights_value(1); + max_range = get_weights_value(2); } else { - return tensorflow::errors::InvalidArgument( - "Unknown quantization op \"", node_def.op(), "\", at ", - node_def.name()); + return errors::InvalidArgument("Unknown quantization op ", node_def.op(), + ", at ", node_def.name()); } + if (params->validation_only) return Status::OK(); + // Store ranges for tensor params->converter->ProvideQuantizationRange( - const_cast(inputs.at(0).tensor()), - min_range, + const_cast(inputs.at(0).tensor()), min_range, max_range); // Sometimes, TRT may not quantize a tensor, either because it chooses to // execute a higher precision kernel or because of op fusion. In these cases, @@ -2077,7 +2052,7 @@ tensorflow::Status ConvertQuantize(OpConverterParams* params) { // possible (i.e. not quantizing in place where fusion will occur), then there // is no problem with the current implementation. 
params->outputs->push_back(inputs.at(0)); - return tensorflow::Status::OK(); + return Status::OK(); } // TODO(pdavoodi): we should update relu6 implementation once TensorRT supports @@ -2087,14 +2062,14 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { const auto& node_def = params->node_def; if (inputs.size() != 1) { return tensorflow::errors::InvalidArgument( - "Invalid number of inputs for Relu6, at ", - node_def.name()); + "Invalid number of inputs for Relu6, at ", node_def.name()); } if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( "Relu6 is only implemented for tensors, not weights, at ", node_def.name()); } + if (params->validation_only) return Status::OK(); // *************************************************************************** // TensorRT does not implement Relu6 natively. This function converts Relu6 op // to available TensorRT ops: Relu6(x) = min(Relu(x), 6) @@ -2110,12 +2085,12 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { nvinfer1::ActivationType::kRELU); TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name()); - // Large range of relu is problematic during quantization in INT8 precision mode. - // Setting dynamic range of relu = [0.f, 6.0f] helps with quantization. + // Large range of relu is problematic during quantization in INT8 precision + // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization. // TRT only uses dynamic ranges in INT8 precision mode, // and this does not affect the FP32 path. - params->converter->ProvideQuantizationRange( - relu_layer->getOutput(0), 0.0f, 6.0f); + params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f, + 6.0f); // Create a constant layer to store the floating point weight i.e. 6.0f This // tensor will be broadcasted uniformly during elementwise `min` operation. 
@@ -2128,14 +2103,14 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { } TRT_ShapedWeights weights = params->weight_store->GetTempWeights( tensorflow::DataType::DT_FLOAT, dims); - auto weights_ptr = static_cast(const_cast( - weights.GetValues())); - weights_ptr[0] = 6.f; + auto weights_ptr = + static_cast(const_cast(weights.GetValues())); + weights_ptr[0] = 6.0f; nvinfer1::IConstantLayer* const6_layer = params->converter->network()->addConstant(dims, weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name()); - params->converter->ProvideQuantizationRange( - const6_layer->getOutput(0), 0.0f, 6.0f); + params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f, + 6.0f); // ElementWise Min Operation // Min op is a nop for INT8 execution path, as the input tensor @@ -2152,107 +2127,110 @@ tensorflow::Status ConvertRelu6(OpConverterParams* params) { params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ConvertScale(OpConverterParams* params) { +tensorflow::Status ConvertBiasAdd(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2 || !inputs.at(0).is_tensor() || !inputs.at(1).is_weights()) { - return tensorflow::errors::Unimplemented( - "ConvertScale only supports tensorweight: ", node_def.name()); - } - - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - TRT_ShapedWeights weights = inputs.at(1).weights(); - if (params->converter->precision_mode() == FP16MODE) { - weights = ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights()); + return errors::InvalidArgument("Input expects tensor and weights, at ", + node_def.name()); } + if (params->validation_only) return Status::OK(); - TRT_ShapedWeights empty_weights(weights.type_); + nvinfer1::ITensor* tensor = + 
const_cast(inputs.at(0).tensor()); + const nvinfer1::Dims original_dims = tensor->getDimensions(); TFAttrs attrs(node_def); - - const auto data_format = attrs.get("data_format"); - int channel_index; - const auto dims = tensor->getDimensions(); - if (data_format == "NHWC") { - // 1). NHWC is really N+C - channel_index = dims.nbDims - 1; // batch dimension is implicit here! - } else { - // 2). NCHW is really N+CHW - channel_index = 0; // batch dimension is implicit here! - } + const string data_format = attrs.get("data_format"); + const int channel_index = + (data_format == "NHWC" ? original_dims.nbDims - 1 : 0); nvinfer1::Permutation permutation; - for (int32_t i = 0; i < dims.nbDims; ++i) { - permutation.order[i] = i; - } - - if (channel_index >= 0) { + if (channel_index != 0) { + // Permute the dimensions so that the channel dimension is the first + // dimension. + for (int i = 0; i < original_dims.nbDims; ++i) { + permutation.order[i] = i; + } permutation.order[0] = channel_index; permutation.order[channel_index] = 0; - } else { - return tensorflow::errors::Unimplemented( - "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name()); + VLOG(1) << "ConvertBiasAdd permutation: " + << DebugString(permutation, original_dims.nbDims); } // TensorRT addScale requires input to be of rank 3, we need to apply - // transpose as well as reshape - if (channel_index != 0 || dims.nbDims != 3) { + // transpose as well as reshape. + // TODO(laigd): this doesn't match what the TRT doc says, fix the doc? 
+ if (channel_index != 0 || original_dims.nbDims != 3) { nvinfer1::IShuffleLayer* shuffle_layer = - params->converter->network()->addShuffle( - *const_cast(tensor)); + params->converter->network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); params->converter->MarkQuantizationRangesAsInferrable( - const_cast(tensor), shuffle_layer->getOutput(0)); + tensor, shuffle_layer->getOutput(0)); + // NOTE(laigd): for some reason we need to apply the reshape + // unconditionally. The default shape has nbDims==-1 and it seems the + // behavior is undefined in some cases. nvinfer1::Dims reshape_dims; reshape_dims.nbDims = 3; - reshape_dims.d[0] = 0; // 0 copy from the input - reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input - reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest + // 0 means copying from input; -1 means inferring from the rest. + reshape_dims.d[0] = 0; + reshape_dims.d[1] = original_dims.nbDims >= 2 ? 0 : 1; + reshape_dims.d[2] = original_dims.nbDims >= 3 ? -1 : 1; + shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { - // maybe we do not need this check. 
concerned about TRT optimization shuffle_layer->setFirstTranspose(permutation); } - shuffle_layer->setReshapeDimensions(reshape_dims); tensor = shuffle_layer->getOutput(0); } + TRT_ShapedWeights weights = inputs.at(1).weights(); + if (params->converter->precision_mode() == FP16MODE) { + weights = ConvertFP32ToFP16(params->weight_store, weights); + } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; if (weights.shape_.d[0] == 1) { mode = nvinfer1::ScaleMode::kUNIFORM; } + TRT_ShapedWeights empty_weights(weights.type_); nvinfer1::IScaleLayer* layer = params->converter->network()->addScale( - *const_cast(tensor), mode, weights.GetTrtWeights(), - empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); + *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), + empty_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - // restore transpose & reshape - if (channel_index != 0 || dims.nbDims != 3) { + // Restore transpose & reshape. + if (channel_index != 0 || original_dims.nbDims != 3) { nvinfer1::IShuffleLayer* shuffle_layer = - params->converter->network()->addShuffle( - *const_cast(output_tensor)); + params->converter->network()->addShuffle(*output_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); - nvinfer1::Dims reshape_dims = dims; - int tmp = reshape_dims.d[channel_index]; - reshape_dims.d[channel_index] = reshape_dims.d[0]; - reshape_dims.d[0] = tmp; + // NOTE: for same reason as mentioned above we need to apply the reshape + // unconditionally. + nvinfer1::Dims reshape_dims = original_dims; + if (channel_index != 0) { + // NOTE: according to NVIDIA dimension types are deprecated, so we don't + // need to copy them back. 
+ reshape_dims.d[channel_index] = original_dims.d[0]; + reshape_dims.d[0] = original_dims.d[channel_index]; + } shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { shuffle_layer->setSecondTranspose(permutation); } params->converter->MarkQuantizationRangesAsInferrable( - const_cast(output_tensor), shuffle_layer->getOutput(0)); + output_tensor, shuffle_layer->getOutput(0)); output_tensor = shuffle_layer->getOutput(0); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); - return tensorflow::Status::OK(); + return Status::OK(); } Status GetTensorDimsWithProtoShape(const Tensor& tensor, @@ -2413,18 +2391,17 @@ tensorflow::Status ConvertIdentity(OpConverterParams* params) { return tensorflow::Status::OK(); } -tensorflow::Status ConvertBinary(OpConverterParams* params) { +Status ConvertBinary(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; if (inputs.size() != 2) { - return tensorflow::errors::FailedPrecondition( - "Binary ops require two tensor input, at ", node_def.name()); + return errors::InvalidArgument("Binary ops require two inputs, at ", + node_def.name()); } // Constant folding should have been done by TensorFlow - if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { - return tensorflow::errors::Unimplemented( + return errors::Unimplemented( "Constant folding is falled back to TensorFlow, binary op received " "both input as constant at: ", node_def.name()); @@ -2436,11 +2413,12 @@ tensorflow::Status ConvertBinary(OpConverterParams* params) { // can be fused in more situations. However, most of the benefits of // IScaleLayer are when the layer performs both a shift and a scale, which we // don't do except for convolutions. - // Try to convert into Scale layer first (for better performance) + // + // Try to convert into Scale layer first (for better performance). 
// Since scale layer supports restricted broadcast policy and op types, we // allow failure and try to handle it through Elementwise op - // (BinaryTensorOpTensor) - Status status = tensorflow::Status::OK(); + // (BinaryTensorOpTensor). + Status status = Status::OK(); if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { status = BinaryTensorOpWeight(params, inputs.at(0).tensor(), inputs.at(1).weights(), false); @@ -2448,7 +2426,10 @@ tensorflow::Status ConvertBinary(OpConverterParams* params) { status = BinaryTensorOpWeight(params, inputs.at(1).tensor(), inputs.at(0).weights(), true); } + // If both input are tensors, or one of them is weights but the conversion + // above failed, try the conversion using BinaryTensorOpTensor. if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { + if (!status.ok()) VLOG(1) << status; status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1)); } return status; @@ -2478,17 +2459,19 @@ tensorflow::Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer; if (node_def.op() == "Rsqrt") { - // We will need a quantization range for intermediate tensor - // if not using calibration. - // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) - // ^ - // need range here + // We will need a quantization range for intermediate tensor if not using + // calibration. 
+ // + // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x) + // ^ + // need range here if (params->converter->precision_mode() == INT8MODE && !params->converter->use_calibration()) { - return tensorflow::errors::Unimplemented( - "Intermediate quantization range cannot be determined without" - " calibration for Rsqrt, consider replacing with " - "Sqrt -> FakeQuant -> Reciprocal ops, at ", node_def.name()); + return errors::Unimplemented( + "Intermediate quantization range cannot be determined without" + " calibration for Rsqrt, consider replacing with " + "Sqrt -> FakeQuant -> Reciprocal ops, at ", + node_def.name()); } layer = params->converter->network()->addUnary( *const_cast(tensor), @@ -3091,40 +3074,49 @@ tensorflow::Status ConvertTopK(OpConverterParams* params) { return tensorflow::Status::OK(); } -void TrtNodeValidator::RegisterOpValidators() { +static void RegisterValidatableOpConverters( + std::unordered_map* registration) { // TODO(laigd): support all op types. - op_validators_["Const"] = ConvertConst; - op_validators_["Transpose"] = ConvertTranspose; - op_validators_["Reshape"] = ConvertReshape; - op_validators_["MatMul"] = ConvertMatMul; + (*registration)["BiasAdd"] = ConvertBiasAdd; + (*registration)["Const"] = ConvertConst; + (*registration)["Transpose"] = ConvertTranspose; + (*registration)["Reshape"] = ConvertReshape; + (*registration)["MatMul"] = ConvertMatMul; + (*registration)["Relu6"] = ConvertRelu6; + + for (auto quantization_op_type : + {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs"}) { + (*registration)[quantization_op_type] = ConvertQuantize; + } + for (auto binary_op_type : + {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum"}) { + (*registration)[binary_op_type] = ConvertBinary; + } +} + +void TrtNodeValidator::RegisterOpValidators() { + RegisterValidatableOpConverters(&op_validators_); } void Converter::RegisterOpConverters() { - // vgg_16 slim implementation + 
RegisterValidatableOpConverters(&op_registry_); + op_registry_["Conv2D"] = ConvertConv2D; op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; op_registry_["AvgPool"] = ConvertPool; - op_registry_["BiasAdd"] = ConvertScale; - op_registry_["Const"] = ConvertConst; // TODO(ben,jie): this is a temp hack. op_registry_["Identity"] = ConvertIdentity; // Identity should be removed op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed - // resnet_50_v1 slim implementation - op_registry_["Add"] = ConvertBinary; - op_registry_["Mul"] = ConvertBinary; - op_registry_["Sub"] = ConvertBinary; op_registry_["Pad"] = ConvertPad; op_registry_["ConcatV2"] = ConvertConcat; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; - op_registry_["Div"] = ConvertBinary; - op_registry_["RealDiv"] = ConvertBinary; - op_registry_["Rsqrt"] = ConvertUnary; op_registry_["Reciprocal"] = ConvertUnary; op_registry_["Exp"] = ConvertUnary; @@ -3133,27 +3125,14 @@ void Converter::RegisterOpConverters() { op_registry_["Abs"] = ConvertUnary; op_registry_["Neg"] = ConvertUnary; - op_registry_["Transpose"] = ConvertTranspose; - op_registry_["Reshape"] = ConvertReshape; - op_registry_["Sum"] = ConvertReduce; op_registry_["Prod"] = ConvertReduce; op_registry_["Max"] = ConvertReduce; op_registry_["Min"] = ConvertReduce; op_registry_["Mean"] = ConvertReduce; - op_registry_["Maximum"] = ConvertBinary; - op_registry_["Minimum"] = ConvertBinary; op_registry_["Softmax"] = ConvertSoftmax; - op_registry_["MatMul"] = ConvertMatMul; op_registry_["BatchMatMul"] = ConvertBatchMatMul; op_registry_["TopKV2"] = ConvertTopK; - op_registry_["Relu6"] = ConvertRelu6; -# if NV_TENSORRT_MAJOR >= 5 - op_registry_["QuantizeAndDequantizeV2"] = ConvertQuantize; - op_registry_["QuantizeAndDequantizeV3"] = ConvertQuantize; - op_registry_["FakeQuantWithMinMaxVars"] = 
ConvertQuantize; - op_registry_["FakeQuantWithMinMaxArgs"] = ConvertQuantize; -#endif plugin_converter_ = ConvertPlugin; } @@ -3164,8 +3143,7 @@ tensorflow::Status ConvertGraphDefToEngine( const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, - TrtUniquePtrType* engine, - bool use_calibration, + TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully) { engine->reset(); if (convert_successfully) *convert_successfully = false; @@ -3254,9 +3232,7 @@ tensorflow::Status ConvertGraphDefToEngine( if (convert_successfully) *convert_successfully = true; // Apply user provided quantization ranges to tensors - const bool warn_missing_ranges = (precision_mode == INT8MODE && - !use_calibration); - converter.ApplyQuantizationRanges(warn_missing_ranges); + converter.MaybeApplyQuantizationRanges(); // Build the engine. VLOG(1) << "Starting engine creation"; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 78749124a2690fe029e0a8b503bf6916efb0cae2..54e19b73957bccdae2b23bd3556de9ad00b864e5 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -147,8 +147,7 @@ tensorflow::Status ConvertGraphDefToEngine( const std::vector& input_shapes, Logger* logger, nvinfer1::IGpuAllocator* allocator, TRTInt8Calibrator* calibrator, - TrtUniquePtrType* engine, - bool use_calibration, + TrtUniquePtrType* engine, bool use_calibration, bool* convert_successfully); // Helper class for the segmenter to determine whether an output edge from the @@ -395,8 +394,7 @@ class TrtNodeValidator { // Class to convert TF nodes to TRT network. 
class Converter { public: - Converter(nvinfer1::INetworkDefinition* trt_network, - int precision_mode, + Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode, bool use_calibration); ////////////////////////////////////////////////////////////////////////////// @@ -442,12 +440,12 @@ class Converter { // This function should be called when we know the quantization range of a // tensor, either from a quantize/dequantize node or when the output is a // fixed range (e.g. SoftMax, Relu6, Sigmoid). - void ProvideQuantizationRange(nvinfer1::ITensor* tensor, - float min_range, float max_range); + void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range, + float max_range); // Should be called when full TRT network has been constructed and before // building the engine. - void ApplyQuantizationRanges(bool warn_missing_ranges); + void MaybeApplyQuantizationRanges(); // Below are helper methods for op converters to add different layers to the // TRT network. @@ -464,6 +462,13 @@ class Converter { const nvinfer1::Dims& dims, const nvinfer1::ITensor** tensor); + // Return OK if the broadcast scheme is supported and compute the shapes after + // broadcasting. + Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims) const; + private: // Verify the provided batch_size is consistent with batch_size_ and update it // if necessary. @@ -482,10 +487,10 @@ class Converter { void RegisterOpConverters(); void PropagateQuantizationRanges(); - + // Gets the min and max value in a TRT_ShapedWeights - Status GetWeightRange(const TRT_ShapedWeights& weights, - float* out_min, float* out_max) const; + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const; // Registered op converters by op type. 
std::unordered_map op_registry_; @@ -503,7 +508,7 @@ class Converter { TrtWeightStore weight_store_; // During conversion, this table is populated with quantization ranges per - // tensor. ApplyQuantizationRanges() will use this table to set the TensorRT + // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT // quantization ranges. Since TRT only supports symmetric ranges, we will // store the range as a single float = max(abs(min_range), abs(max_range)). // Range refers to the floating point values, e.g. min_range = 0.0f, max_range @@ -514,8 +519,8 @@ class Converter { // first tensor to second tensor. PropagateQuantizationRanges() will propagate // known ranges from quantization_ranges_ across these edges, adding the new // ranges to quantization_ranges_ so that they can be applied in - // ApplyQuantizationRanges(). - std::vector> + // MaybeApplyQuantizationRanges(). + std::vector> quantization_infer_; const int precision_mode_; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index 257394c0a308cdeced07b81929a69d583fbe0c40..603c4f7b5e5af8df7f81484c715675968f5da695 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -35,7 +35,10 @@ limitations under the License. 
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT +#include "tensorflow/core/public/session.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -47,7 +50,9 @@ namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrCat; using ::testing::ElementsAre; +using ::testing::ElementsAreArray; // TODO(laigd): put this into some test utils file. void ExpectStatus(Status status, error::Code code = error::OK, @@ -69,6 +74,32 @@ nvinfer1::Dims GetTestDims(const std::vector& d) { return dims; } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { + switch (tf_dtype) { + case DT_FLOAT: + return nvinfer1::DataType::kFLOAT; + case DT_HALF: + return nvinfer1::DataType::kHALF; + case DT_INT32: + return nvinfer1::DataType::kINT32; + default: + QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); + } +} + +DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return DT_FLOAT; + case nvinfer1::DataType::kHALF: + return DT_HALF; + case nvinfer1::DataType::kINT32: + return DT_INT32; + default: + QCHECK(false) << "Unexpected data type " << static_cast(trt_dtype); + } +} + NodeDef MakeNodeDef(const string& name, const string& op, const std::vector& inputs) { NodeDef node_def; @@ -111,6 +142,15 @@ bool TrtDimsEqualsArray(const std::vector& lhs, return TrtDimsEquals(GetTestDims(lhs), rhs); } +// TODO(laigd): define a parameterized matcher that can compare against the +// vector. 
+void ExpectTrtDimsEqualsArray(const std::vector& lhs, + const nvinfer1::Dims& rhs) { + EXPECT_TRUE(TrtDimsEqualsArray(lhs, rhs)) + << "expected: " << DebugString(GetTestDims(lhs)) << "\n" + << " actual: " << DebugString(rhs); +} + bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs, const TRT_ShapedWeights& rhs) { return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ && @@ -121,8 +161,7 @@ template void ValidateWeights(const TRT_ShapedWeights& weights, const std::vector& expected_dims, const std::vector& expected_value) { - EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_)) - << weights.DebugString(); + ExpectTrtDimsEqualsArray(expected_dims, weights.shape_); ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString(); const T* actual_values = static_cast(weights.GetValues()); for (int i = 0; i < expected_value.size(); ++i) { @@ -272,9 +311,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(1, ptr->batch_size()); } EXPECT_EQ(&itensor, ptr->tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims()); } } } @@ -293,9 +330,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { EXPECT_EQ(false, ptr->is_weights()); EXPECT_EQ(1, ptr->batch_size()); EXPECT_NE(nullptr, ptr->tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims()); } } // Test constructor with TRT_ShapedWeights argument. 
@@ -312,9 +347,7 @@ TEST(TRT_TensorOrWeights_Test, Basic) { nvinfer1::Dims dims; dims.nbDims = 0; - EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims())) - << "- expected: " << DebugString(dims) - << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims()); + ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims()); } } } @@ -348,34 +381,50 @@ TEST_F(ValidatorTest, ConvertToTensorOrWeights) { graph_properties, &output)); ValidateWeights(output.weights(), {2}, {1.0, 2.0}); } - // Convert non-Const. We test the case where the non-batch dimemsion is - // unknown as well, to make sure the validator allows that. - for (const int32 non_batch_dim : {-1, 2}) { - const int32 batch_size = 12; + // Helper method to run ConvertToTensorOrWeights() with predefined parameters. + auto convert_to_tensor_or_weights = [this](const std::vector& dims, + TRT_TensorOrWeights* output) { Scope s = Scope::NewRootScope(); - ops::Placeholder::Attrs attrs; - TF_EXPECT_OK(TensorShapeUtils::MakeShape( - std::vector{batch_size, non_batch_dim}, &attrs.shape_)); + const auto attrs = ops::Placeholder::Shape(PartialTensorShape{dims}); auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT, attrs); auto add = ops::Add(s.WithOpName("add"), feed, feed); grappler::GrapplerItem item; TF_EXPECT_OK(s.ToGraphDef(&item.graph)); - grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - - auto& node_def = add.operation.node()->def(); + const NodeDef& node_def = add.operation.node()->def(); + return this->ConvertToTensorOrWeights(node_def, /*output_port=*/0, + graph_properties, output); + }; + // Convert non-Const with #dims > nvinfer1::Dims::MAX_DIMS+1. 
+ { TRT_TensorOrWeights output; - ExpectStatus(ConvertToTensorOrWeights(node_def, /*output_port=*/0, - graph_properties, &output)); + ExpectStatus( + convert_to_tensor_or_weights( + std::vector(nvinfer1::Dims::MAX_DIMS + 2, 1), &output), + error::OUT_OF_RANGE, "Input tensor rank is greater than 9"); + } + // Convert non-Const with #dims < 2. + { + TRT_TensorOrWeights output; + ExpectStatus( + convert_to_tensor_or_weights({1}, &output), error::INVALID_ARGUMENT, + "Input tensor with rank<2 is not supported since the first dimension " + "is treated as batch dimension by TRT"); + } + // Convert non-Const. We test the case where the non-batch dimemsion is + // unknown as well, to make sure the validator allows that. + for (const int32 non_batch_dim : {-1, 2}) { + const int32 batch_size = 12; + TRT_TensorOrWeights output; + ExpectStatus( + convert_to_tensor_or_weights({batch_size, non_batch_dim}, &output)); EXPECT_EQ(true, output.is_tensor()); EXPECT_EQ(batch_size, output.batch_size()); EXPECT_NE(nullptr, output.tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims())) - << "- expected: {" << non_batch_dim << "} \n vs\n" - << "- actual: " << DebugString(output.GetTrtDims()); + ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims()); } } @@ -526,9 +575,9 @@ TEST_F(ConverterTest, AddAndGetInputs) { EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType()); EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType()); EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType()); - EXPECT_TRUE(TrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions())); - EXPECT_TRUE(TrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions())); - EXPECT_TRUE(TrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions())); + ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions()); + ExpectTrtDimsEqualsArray({5, 3}, 
inputs[3].tensor()->getDimensions()); } TEST_F(ConverterTest, RenameAndMarkOutputTensors) { @@ -574,7 +623,7 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { {{"my_op", "my_output"}, {"my_op:1", "my_output_1"}})); EXPECT_EQ(2, output_tensors.size()); for (auto output_tensor : output_tensors) { - EXPECT_TRUE(TrtDimsEqualsArray({2, 1}, output_tensor->getDimensions())); + ExpectTrtDimsEqualsArray({2, 1}, output_tensor->getDimensions()); } EXPECT_EQ("my_output", string(output_tensors[0]->getName())); EXPECT_EQ("my_output_1", string(output_tensors[1]->getName())); @@ -599,8 +648,7 @@ TEST_F(ConverterTest, TransposeTensor) { // OK. TF_EXPECT_OK( converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { @@ -612,7 +660,7 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { // Shape size doesn't match. ExpectStatus(converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}), &output_tensor), - error::INVALID_ARGUMENT, "Reshape shapes are not compatible."); + error::INVALID_ARGUMENT, "Reshape shapes are not compatible"); // TODO(aaroey): we should check the case where uninferred dimensions are not // an exact divisor of input dim ensions, e.g. for dims {-1, 7}. @@ -620,14 +668,12 @@ TEST_F(ConverterTest, PrepareTensorForShape_Tensor) { // Infer shape, ok. TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({-1, 2}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({15, 2}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()); // Regular shape. 
TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, PrepareTensorForShape_Weights) { @@ -637,8 +683,7 @@ TEST_F(ConverterTest, PrepareTensorForShape_Weights) { const nvinfer1::ITensor* output_tensor = nullptr; TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}), &output_tensor)); - EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions())) - << DebugString(*output_tensor); + ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()); } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -678,51 +723,57 @@ TEST_F(ConverterTest, AddAndGetTensorOrWeights) { "tensor/weights my_tensor already exist"); } -TEST_F(ConverterTest, GetWeightRange) { +template +void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) { TRT_ShapedWeights weights = - weight_store_->GetTempWeights(DT_FLOAT, GetTestDims({2, 3})); - const std::vector values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + weight_store->GetTempWeights(DataTypeToEnum::v(), GetTestDims({2, 3})); + const std::vector values = {T(3), T(1), T(2), T(6), T(5), T(4)}; memcpy(const_cast(weights.GetValues()), values.data(), weights.size_bytes()); float out_min = 0.0f; float out_max = 0.0f; - TF_EXPECT_OK(GetWeightRange(weights, &out_min, &out_max)); + TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max)); EXPECT_EQ(1.0f, out_min); EXPECT_EQ(6.0f, out_max); } +TEST_F(ConverterTest, GetWeightRange) { + TestGetWeightRange(this, weight_store_); + TestGetWeightRange(this, weight_store_); + TestGetWeightRange(this, weight_store_); +} + TEST_F(ConverterTest, ProvideQuantizationRange) { FakeITensor fake_tensor; // Assymetric range converter_->ProvideQuantizationRange(&fake_tensor, 0.0f, 6.0f); - 
EXPECT_EQ(quantization_ranges()[&fake_tensor], 6.0f); + EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]); converter_->ProvideQuantizationRange(&fake_tensor, 1.0f, 6.0f); - EXPECT_EQ(quantization_ranges()[&fake_tensor], 6.0f); + EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]); converter_->ProvideQuantizationRange(&fake_tensor, -8.0f, 6.0f); - EXPECT_EQ(quantization_ranges()[&fake_tensor], 8.0f); + EXPECT_EQ(8.0f, quantization_ranges()[&fake_tensor]); converter_->ProvideQuantizationRange(&fake_tensor, -8.123f, -6.123f); - EXPECT_EQ(quantization_ranges()[&fake_tensor], 8.123f); + EXPECT_EQ(8.123f, quantization_ranges()[&fake_tensor]); // Symmetric range converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f); - EXPECT_EQ(quantization_ranges()[&fake_tensor], 6.123f); + EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]); } -TEST_F(ConverterTest, ApplyQuantizationRanges) { +TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { // input -> infer1 -> infer2 -> infer3 - FakeITensor input; - FakeITensor infer_1; - FakeITensor infer_2; - FakeITensor infer_3; + FakeITensor input, infer_1, infer_2, infer_3; FakeITensor not_infer; - converter_->ProvideQuantizationRange(&input, -5.0f, 5.0f); - converter_->ProvideQuantizationRange(¬_infer, -100.0f, 100.0f); - converter_->MarkQuantizationRangesAsInferrable(&input, &infer_1); - converter_->MarkQuantizationRangesAsInferrable(&infer_1, &infer_2); - converter_->MarkQuantizationRangesAsInferrable(&infer_2, &infer_3); + Converter int8_converter(/*trt_network=*/nullptr, INT8MODE, + /*use_calibration=*/true); + int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f); + int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f); + int8_converter.MarkQuantizationRangesAsInferrable(&input, &infer_1); + int8_converter.MarkQuantizationRangesAsInferrable(&infer_1, &infer_2); + int8_converter.MarkQuantizationRangesAsInferrable(&infer_2, &infer_3); // Input range should be inferred along the chain and applied 
to tensors. - converter_->ApplyQuantizationRanges(/*warn_missing_ranges=*/false); + int8_converter.MaybeApplyQuantizationRanges(); #if NV_TENSORRT_MAJOR >= 5 EXPECT_EQ(input.getDynamicRange(), 5.0f); EXPECT_EQ(infer_1.getDynamicRange(), 5.0f); @@ -733,27 +784,117 @@ TEST_F(ConverterTest, ApplyQuantizationRanges) { } TEST_F(ConverterTest, PropagateQuantizationRanges) { - // input <-> infer1 <-> infer2 <-> infer3 - FakeITensor input; - FakeITensor infer_1; - FakeITensor infer_2; - FakeITensor infer_3; + // infer0 <-> infer1 <-> infer2 <-> infer3 + // | + // infer4 <-> infer5 + FakeITensor infer[6]; FakeITensor not_infer; - converter_->ProvideQuantizationRange(&input, -5.0f, 5.0f); - converter_->MarkQuantizationRangesAsInferrable(&input, &infer_1); - converter_->MarkQuantizationRangesAsInferrable(&infer_1, &infer_2); - converter_->MarkQuantizationRangesAsInferrable(&infer_3, &infer_2); + converter_->ProvideQuantizationRange(&infer[4], -5.0f, 5.0f); + converter_->MarkQuantizationRangesAsInferrable(&infer[0], &infer[1]); + converter_->MarkQuantizationRangesAsInferrable(&infer[1], &infer[2]); + converter_->MarkQuantizationRangesAsInferrable(&infer[3], &infer[2]); + converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[1]); + converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[5]); // Input range should be inferred along the chain. 
PropagateQuantizationRanges(); auto ranges = quantization_ranges(); - EXPECT_EQ(ranges[&input], 5.0f); - EXPECT_EQ(ranges[&infer_1], 5.0f); - EXPECT_EQ(ranges[&infer_2], 5.0f); - EXPECT_EQ(ranges[&infer_3], 5.0f); + for (int i = 0; i < 6; ++i) { + EXPECT_EQ(5.0f, ranges[&infer[i]]); + } EXPECT_EQ(ranges.count(¬_infer), 0); } +TEST_F(ConverterTest, GetTrtBroadcastShape) { + const bool kIsTensor = true; + const bool kIsNotTensor = false; + auto symmetric_test = [this](const std::vector& operand_1_shape, + const std::vector& operand_2_shape, + const bool operand_1_is_tensor, + const bool operand_2_is_tensor, + const std::vector& expected_operand_1_shape, + const std::vector& expected_operand_2_shape, + error::Code expected_code = error::OK, + const char* expected_error_msg_substr = nullptr, + const int operand_1_batch_size = -1, + const int operand_2_batch_size = -1) { + auto create_tensor_or_weights = [](const std::vector& shape, + bool is_tensor, int batch_size = -1) { + if (is_tensor) { + return TRT_TensorOrWeights{nvinfer1::DataType::kFLOAT, + GetTestDims(shape), batch_size}; + } + TRT_ShapedWeights weights; + weights.shape_ = GetTestDims(shape); + return TRT_TensorOrWeights(weights); + }; + + nvinfer1::Dims operand_1_new_dims, operand_2_new_dims; + TRT_TensorOrWeights operand_1 = create_tensor_or_weights( + operand_1_shape, operand_1_is_tensor, operand_1_batch_size); + TRT_TensorOrWeights operand_2 = create_tensor_or_weights( + operand_2_shape, operand_2_is_tensor, operand_2_batch_size); + + // operand_1 broadcast operand_2 + ExpectStatus( + this->converter_->GetTrtBroadcastShape( + operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims), + expected_code, expected_error_msg_substr); + if (expected_code == error::OK) { + ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); + ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); + } + // operand_2 broadcast operand_1 + ExpectStatus( + 
this->converter_->GetTrtBroadcastShape( + operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims), + expected_code, expected_error_msg_substr); + if (expected_code == error::OK) { + ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims); + ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims); + } + }; + + // Both inputs are weights. + symmetric_test( + {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT, + "Broadcasting requires at least one of the operands be tensors"); + + // One tensor and one weights. + symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2}); + symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2}); + symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1}); + symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {1, 2, 3}); + symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {2, 3, 4}); + symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1}, + {2, 3, 4}); + symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4}, + {2, 1, 4}); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); + symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, "Infeasible broadcast scheme", + /*operand_1_batch_size=*/2); + symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); + + // Both inputs are tensors. 
+ symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 3 vs broadcast #dims 4)"); + symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4}, + {2, 1, 4}); + symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {}, + error::INVALID_ARGUMENT, + "Broadcasting beyond batch dimension is not supported " + "(tensor #dims 4 vs broadcast #dims 5)"); +} + // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -791,8 +932,12 @@ class OpConverterTest : public ::testing::Test { validator_inputs_.clear(); } - void BuildAndRun(const char* input_name, const std::vector& input_data, - const char* output_name, std::vector* output_data) { + // TODO(laigd): test fp16 and int8 support. + template + void BuildAndRun( + const std::vector>>& + input_data, + const char* output_name, std::vector* output_data) { // Mark the output tensor as TRT engine output. TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors( {{string(output_name), string(output_name)}})); @@ -803,25 +948,33 @@ class OpConverterTest : public ::testing::Test { CHECK_NOTNULL(engine_.get()); // Execute the TRT engine. 
- const int input_size = input_data.size() * sizeof(float); - const int output_size = output_data->size() * sizeof(float); - const int input_index = engine_->getBindingIndex(input_name); - const int output_index = engine_->getBindingIndex(output_name); + ASSERT_LE(input_data.size() + 1, 3); + void* buffers[3]; + for (const auto name_and_data : input_data) { + const int input_size = name_and_data.second.size() * sizeof(T); + const int input_index = engine_->getBindingIndex(name_and_data.first); + ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); + ASSERT_EQ( + 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(), + input_size, cudaMemcpyHostToDevice, stream_)); + } - ASSERT_EQ(engine_->getNbBindings(), 2); - void* buffers[2]; - ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size)); + const int output_size = output_data->size() * sizeof(T); + const int output_index = engine_->getBindingIndex(output_name); ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size)); - ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input_data.data(), - input_size, cudaMemcpyHostToDevice, stream_)); + + ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1); + TrtUniquePtrType execution_context( engine_->createExecutionContext()); execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr); ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index], output_size, cudaMemcpyDeviceToHost, stream_)); cudaStreamSynchronize(stream_); - ASSERT_EQ(0, cudaFree(buffers[input_index])); - ASSERT_EQ(0, cudaFree(buffers[output_index])); + + for (int i = 0; i < input_data.size() + 1; ++i) { + ASSERT_EQ(0, cudaFree(buffers[i])); + } } bool HasStaticShape(const nvinfer1::Dims& dims) const { @@ -836,18 +989,7 @@ class OpConverterTest : public ::testing::Test { void AddTestTensor( const char* name, const std::vector& dims, int batch_size = 1, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { - DataType tf_dtype = DT_FLOAT; - 
switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - tf_dtype = DT_FLOAT; - break; - case nvinfer1::DataType::kINT32: - tf_dtype = DT_INT32; - break; - default: - ASSERT_TRUE(false) << "Unexpected data type " - << static_cast(trt_dtype); - } + DataType tf_dtype = TrtDataTypeToTf(trt_dtype); ops::Placeholder::Attrs attrs; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); attrs.shape_.InsertDim(0, batch_size); @@ -940,6 +1082,11 @@ class OpConverterTest : public ::testing::Test { TrtUniquePtrType network_; TrtUniquePtrType engine_; cudaStream_t stream_; + // Used to create placeholders with shape and data type information. The + // created placeholders will be used as inputs to the node to be verified, + // thus we need the shape and data type information to get a non-empty + // GraphProperties. + // TODO(laigd): consider use this Scope to create the NodeDef to verify. Scope scope_; std::unordered_map validator_inputs_; }; @@ -1063,15 +1210,15 @@ TEST_F(OpConverterTest, ConvertTranspose) { Reset(); AddTestTensor("input", {1, 2, 3}); AddTestWeights("weights", {4}, {0, 3, 1, 2}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions())) - << output.DebugString(); + ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); std::vector output_data(6); - BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_transpose", &output_data); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose", + &output_data); EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6)); } } @@ -1153,15 +1300,15 @@ TEST_F(OpConverterTest, ConvertReshape) { Reset(); AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size); AddTestWeights("weights", {4}, ok_params[i].shape); - RunConversion(node_def); + RunValidationAndConversion(node_def); 
TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions())) - << output.DebugString(); + ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()); std::vector output_data(6); - BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_reshape", &output_data); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape", + &output_data); EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6)); } } @@ -1175,15 +1322,14 @@ TEST_F(OpConverterTest, ConvertMatMul) { "Input expects tensor and weights, at my_matmul"); } - // Get the NodeDef for Reshape. + // Get the NodeDef for MatMul. auto get_matmul_nodedef = [](DataType dtype, bool transpose_a, bool transpose_b) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), dtype); auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); - ops::MatMul::Attrs matmul_attrs; - matmul_attrs.transpose_a_ = transpose_a; - matmul_attrs.transpose_b_ = transpose_b; + const auto matmul_attrs = + ops::MatMul::TransposeA(transpose_a).TransposeB(transpose_b); auto matmul = ops::MatMul(s.WithOpName("my_matmul"), input, weights, matmul_attrs); return matmul.operation.node()->def(); @@ -1199,82 +1345,499 @@ TEST_F(OpConverterTest, ConvertMatMul) { node_def, error::UNIMPLEMENTED, "Data type is not supported, for node my_matmul got int32"); } - { - // transpose_a is set. - for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "transpose_a is not supported for TensorRT FullyConnected"); + // transpose_a is set. 
+ for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/true, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "transpose_a is not supported for TensorRT FullyConnected"); + } + // OK. + for (bool transpose_b : {false, true}) { + Reset(); + NodeDef node_def = + get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); + AddTestTensor("input", {2}, /*batch_size=*/1); + AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); + + std::vector output_data(2); + BuildAndRun({{"input", {0, 1}}}, "my_matmul", &output_data); + if (transpose_b) { + EXPECT_THAT(output_data, ElementsAre(1, 3)); + } else { + EXPECT_THAT(output_data, ElementsAre(2, 3)); } } - { - // OK. - for (bool transpose_b : {false, true}) { - Reset(); - NodeDef node_def = - get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b); - AddTestTensor("input", {2}, /*batch_size=*/1); - AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); - RunConversion(node_def); +} + +template +void TestConvertBiasAdd(OpConverterTest* test) { + // Get the NodeDef for BiasAdd. 
+ auto get_biasadd_nodedef = [](const string& data_format) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format); + auto biasadd = + ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs); + return biasadd.operation.node()->def(); + }; + + typedef typename EnumToDataType::Type CType; + for (const string& data_format : {"NHWC", "NCHW"}) { + for (const int trt_input_rank : {1, 2, 3, 4}) { + test->Reset(); + NodeDef node_def = get_biasadd_nodedef(data_format); + + // Add input, dims_array will be like {2, 1, ..., 1, 3} + std::vector dims_array(trt_input_rank, 1); + if (trt_input_rank == 1) { + dims_array[0] = (data_format == "NHWC" ? 3 : 2); + } else { + dims_array[0] = 2; + dims_array[trt_input_rank - 1] = 3; + } + test->AddTestTensor("input", dims_array, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + + // Add bias weights. + const int channel_size = (data_format == "NHWC" ? 3 : 2); + std::vector bias(channel_size); + for (int i = 0; i < channel_size; ++i) { + bias[i] = CType(i + 1); // bias will be {1, 2, 3, ...} + } + test->AddTestWeights("weights", {channel_size}, bias); + + // Run the conversion. + test->RunValidationAndConversion(node_def); TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output)); + TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output)); EXPECT_TRUE(output.is_tensor()); - EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions())) - << output.DebugString(); - - std::vector output_data(2); - BuildAndRun("input", {0, 1}, "my_matmul", &output_data); - if (transpose_b) { - EXPECT_THAT(output_data, ElementsAre(1, 3)); + ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions()); + + // Build and run the engine. 
+ const int num_input = TrtDimsNumElements(GetTestDims(dims_array)); + ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2), + num_input); + std::vector output_data(num_input); + test->BuildAndRun( + {{"input", std::vector(num_input, CType(0))}}, "my_biasadd", + &output_data); + if (trt_input_rank == 1) { + if (data_format == "NHWC") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3))); + } else { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2))); + } } else { - EXPECT_THAT(output_data, ElementsAre(2, 3)); + if (data_format == "NHWC") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3), + CType(1), CType(2), CType(3))); + } else { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1), + CType(2), CType(2), CType(2))); + } } } } } -TEST_F(OpConverterTest, ConvertQuantize) { +TEST_F(OpConverterTest, ConvertBiasAdd) { { // Input list is empty, should fail. - NodeDef node_def = - MakeNodeDef("my_quantize", "QuantizeAndDequantizeV2", {}); - RunConversion( + NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {}); + RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, - "Invalid number of inputs for QuantizeAndDequantizeV2, at my_quantize"); + "Input expects tensor and weights, at my_biasadd"); + } + + // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test + // DT_INT32 type here. 
+ TestConvertBiasAdd(this); + TestConvertBiasAdd(this); +} + +template +NodeDef GetBinaryOpNodeDef(const string& input_name_l, + const string& input_name_r, DataType dtype) { + Scope s = Scope::NewRootScope(); + auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype); + auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype); + auto op = OpType(s.WithOpName("my_binary"), input_l, input_r); + return op.operation.node()->def(); +} + +void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) { + bool element_wise_layer_found = false; + bool scale_layer_found = false; + for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) { + nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i); + if (dynamic_cast(layer)) { + scale_layer_found = true; + } else if (dynamic_cast(layer)) { + element_wise_layer_found = true; + } + } + EXPECT_EQ(expect_scale_layer, scale_layer_found); + EXPECT_NE(expect_scale_layer, element_wise_layer_found); +} + +template +void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + for (auto swap_inputs : {false, true}) { + test->Reset(); + NodeDef node_def; + if (swap_inputs) { + node_def = GetBinaryOpNodeDef("weights", "input", dtype); + } else { + node_def = GetBinaryOpNodeDef("input", "weights", dtype); + } + + const std::vector operand1{CType(3), CType(7.5)}; + const std::vector operand2{CType(2), CType(3)}; + + // It requires the dims to be at least of rank 3 to apply an IScaleLayer. + test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", /*dims=*/{1, 1, 2}, + /*values=*/swap_inputs ? operand1 : operand2); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. 
+ TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(2); + test->BuildAndRun( + {{"input", + /*input_data=*/swap_inputs ? operand2 : operand1}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5))); + } else if (node_def.op() == "Sub") { + EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5))); + } else if (node_def.op() == "Mul") { + EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5))); + } else if (node_def.op() == "Div") { + EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + } else if (node_def.op() == "RealDiv") { + EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5))); + } else { + ASSERT_TRUE(false); + } + } +} + +template +void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + const std::vector input{CType(1), CType(2), CType(3), CType(4)}; + const std::vector weights{CType(10), CType(20)}; + // There are two types of valid dim pairs which requires channel-wise + // broadcasting: + // - input dims (X Y Z) vs weights dims (X 1 1) + // - input dims (X Y Z) vs weights dims (Z) + // Here X=Z=2 and Y=1. + for (auto weights_dims : std::vector>{{2, 1, 1}, {2}}) { + test->Reset(); + test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", weights_dims, weights); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. 
+ TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + if (weights_dims.size() == 1) { + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(22), CType(13), CType(24))); + } else { + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(12), CType(23), CType(24))); + } + } +} + +template +void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + const std::vector input{CType(1), CType(2), CType(3), CType(4)}; + const std::vector weights{CType(10)}; + test->Reset(); + test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestWeights("weights", {1, 1, 1, 1}, weights); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor. + CheckAddedLayers(test, /*expect_scale_layer=*/true); + + // Check the dims of the output ITensor. 
+ TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + test->BuildAndRun({{"input", input}}, "my_binary", &output_data); + EXPECT_THAT(output_data, + ElementsAre(CType(11), CType(12), CType(13), CType(14))); +} + +template +void TestBinaryTensorOpWeightFallback(OpConverterTest* test, + const std::vector& input_dims, + const std::vector& weights_dims, + error::Code code = error::OK, + const char* error_msg_substr = nullptr, + const int input_batch_size = 1) { + const DataType dtype = DT_FLOAT; + typedef typename EnumToDataType::Type CType; + const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims)); + const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims)); + + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input", "weights", dtype); + test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size, + TfDataTypeToTrt(dtype)); + test->AddTestWeights( + "weights", /*dims=*/weights_dims, + /*values=*/std::vector(num_weights, CType(1))); + test->RunValidationAndConversion(node_def, code, error_msg_substr); + if (code != error::OK) return; + + // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. + CheckAddedLayers(test, /*expect_scale_layer=*/false); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + + // Check the dims of the output ITensor. + std::vector expected_output_dims = input_dims; + for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1; + i >= 0 && j >= 0; --i, --j) { + if (expected_output_dims[i] == 1) { + expected_output_dims[i] = weights_dims[j]; + } + } + ExpectTrtDimsEqualsArray(expected_output_dims, + output.tensor()->getDimensions()); + + // Check the result of running the engine. 
+ const int expected_num_outputs = + TrtDimsNumElements(GetTestDims(expected_output_dims)); + std::vector output_data(expected_num_outputs); + test->BuildAndRun( + {{"input", + /*input_data=*/std::vector(num_inputs, CType(2))}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, ElementsAreArray(std::vector( + expected_num_outputs, CType(3)))); + } else if (node_def.op() == "Minimum") { + EXPECT_THAT(output_data, ElementsAreArray(std::vector( + expected_num_outputs, CType(1)))); + } else { + ASSERT_TRUE(false); + } +} + +template +void TestBinaryTensorOpTensor(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + test->Reset(); + const NodeDef node_def = + GetBinaryOpNodeDef("input1", "input2", dtype); + test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1, + TfDataTypeToTrt(dtype)); + test->RunValidationAndConversion(node_def); + + // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight. + CheckAddedLayers(test, /*expect_scale_layer=*/false); + + // Check output dims. + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output)); + EXPECT_TRUE(output.is_tensor()); + ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); + + std::vector output_data(4); + // After broadcasting first input becomes {3, 6, 3, 6} and second input + // becomes {2, 3, 2, 3}. 
+ test->BuildAndRun( + {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}}, + "my_binary", &output_data); + if (node_def.op() == "Add") { + EXPECT_THAT(output_data, + ElementsAre(CType(5), CType(8), CType(6), CType(9))); + } else if (node_def.op() == "Sub") { + EXPECT_THAT(output_data, + ElementsAre(CType(1), CType(4), CType(0), CType(3))); + } else if (node_def.op() == "Mul") { + EXPECT_THAT(output_data, + ElementsAre(CType(6), CType(12), CType(9), CType(18))); + } else if (node_def.op() == "Div") { + EXPECT_THAT(output_data, + ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + } else if (node_def.op() == "RealDiv") { + EXPECT_THAT(output_data, + ElementsAre(CType(1.5), CType(3), CType(1), CType(2))); + } else if (node_def.op() == "Minimum") { + EXPECT_THAT(output_data, + ElementsAre(CType(2), CType(2), CType(3), CType(3))); + } else if (node_def.op() == "Maximum") { + EXPECT_THAT(output_data, + ElementsAre(CType(3), CType(6), CType(3), CType(6))); + } else { + ASSERT_TRUE(false); + } +} + +TEST_F(OpConverterTest, ConvertBinary) { + // Input size doesn't match, should fail. + for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) { + Reset(); + NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"}); + AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT); + RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, + "Binary ops require two inputs, at my_add"); + } + { + // Both inputs are weights. + Reset(); + NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"}); + AddTestWeights("weights1", {1}, {1}); + AddTestWeights("weights2", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Constant folding is falled back to TensorFlow, binary op received " + "both input as constant at: my_add"); + } + + // Test BinaryTensorOpWeight() without broadcasting. 
+ TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); +#if 0 + // TODO(b/119560144): it doesn't support FP16 constants and the following test + // will fail. + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); + TestBinaryTensorOpWeightNoBroadcast(this); +#endif + + // Test BinaryTensorOpWeight() with channel-wise broadcasting. + TestBinaryTensorOpWeightWithChannelWiseBroadcast(this); + + // Test BinaryTensorOpWeight() with uniformly broadcasting. + TestBinaryTensorOpWeightWithUniformlyBroadcast(this); + + // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor(). + // Unsupported op. + TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1}); + // Rank of input tensor dimension <3. + TestBinaryTensorOpWeightFallback(this, {1, 1}, {1}); + // Broadcast on batch dimension, should fail. + TestBinaryTensorOpWeightFallback( + this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT, + "Unsupported binary op broadcast scheme for op my_binary", + /*input_batch_size=*/2); + // Incompatible dims with per-channel mode. + TestBinaryTensorOpWeightFallback(this, {1, 1, 1}, {1, 2, 1}); + // Incompatible dims. + TestBinaryTensorOpWeightFallback(this, {1, 2, 1}, {2}); + + // Test BinaryTensorOpTensor() with broadcasting. 
+ TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); + TestBinaryTensorOpTensor(this); +} + +TEST_F(OpConverterTest, ConvertQuantize) { + for (const string& op : + {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars", + "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_quantize", op, {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + StrCat("Invalid number of inputs for ", op, ", at my_quantize") + .c_str()); } { // FakeQuantWithMinMaxArgs attributes are empty, should fail. NodeDef node_def = MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"}); AddTestTensor("input", {1, 2, 3}); - RunConversion(node_def, error::INVALID_ARGUMENT, - "Min or max attribute not found for FakeQuantWithMinMaxArgs " - "at my_quantize"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Min or max attribute not found for FakeQuantWithMinMaxArgs " + "at my_quantize"); } { // FakeQuantWithMinMaxArgs ranges set via attributes, ok. 
Reset(); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - ops::FakeQuantWithMinMaxArgs::Attrs quantize_attrs; - quantize_attrs.min_ = -6.0f; - quantize_attrs.max_ = 6.0f; + auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"), input, quantize_attrs); const NodeDef& node_def = quantize.operation.node()->def(); AddTestTensor("input", {1, 2, 3}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); EXPECT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); - EXPECT_EQ(ranges.count(output.tensor()), 1); - EXPECT_EQ(ranges[output.tensor()], 6.0f); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); } { // FakeQuantWithMinMaxVars ranges set via inputs, ok. @@ -1289,13 +1852,13 @@ TEST_F(OpConverterTest, ConvertQuantize) { AddTestTensor("input", {1, 2, 3}); AddTestWeights("weights_min", {1}, {-6.0f}); AddTestWeights("weights_max", {1}, {6.0f}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); EXPECT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); - EXPECT_EQ(ranges.count(output.tensor()), 1); - EXPECT_EQ(ranges[output.tensor()], 6.0f); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); } { // QuantizeAndDequantizeV2 ranges set via inputs, ok. 
@@ -1310,13 +1873,31 @@ TEST_F(OpConverterTest, ConvertQuantize) { AddTestTensor("input", {1, 2, 3}); AddTestWeights("weights_min", {1}, {-6.0f}); AddTestWeights("weights_max", {1}, {6.0f}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); EXPECT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); - EXPECT_EQ(ranges.count(output.tensor()), 1); - EXPECT_EQ(ranges[output.tensor()], 6.0f); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); + } + { + // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. + Reset(); + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); + auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); + auto quantize = ops::QuantizeAndDequantizeV2( + s.WithOpName("my_quantize"), input, weights_min, weights_max); + const NodeDef& node_def = quantize.operation.node()->def(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights_min", {1}); + AddTestTensor("weights_max", {1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " + "tensors, at my_quantize"); } { // QuantizeAndDequantizeV3 ranges set via inputs, ok. 
@@ -1333,31 +1914,13 @@ TEST_F(OpConverterTest, ConvertQuantize) { AddTestWeights("weights_min", {1}, {-6.0f}); AddTestWeights("weights_max", {1}, {6.0f}); AddTestWeights("num_bits", {1}, {8}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output)); EXPECT_TRUE(output.is_tensor()); auto ranges = quantization_ranges(); - EXPECT_EQ(ranges.count(output.tensor()), 1); - EXPECT_EQ(ranges[output.tensor()], 6.0f); - } - { - // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. - Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); - auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT); - auto quantize = ops::QuantizeAndDequantizeV2( - s.WithOpName("my_quantize"), input, weights_min, weights_max); - const NodeDef& node_def = quantize.operation.node()->def(); - AddTestTensor("input", {1, 2, 3}); - AddTestTensor("weights_min", {1}); - AddTestTensor("weights_max", {1}); - RunConversion( - node_def, error::INVALID_ARGUMENT, - "Min and max inputs for QuantizeAndDequantizeV2 must be weights not " - "tensors, at my_quantize"); + EXPECT_EQ(1, ranges.count(output.tensor())); + EXPECT_EQ(6.0f, ranges[output.tensor()]); } } @@ -1365,21 +1928,29 @@ TEST_F(OpConverterTest, ConvertRelu6) { { // Input list is empty, should fail. NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {}); - RunConversion(node_def, error::INVALID_ARGUMENT, - "Invalid number of inputs for Relu6, at my_relu6"); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Invalid number of inputs for Relu6, at my_relu6"); } // Get the NodeDef for Relu6. 
Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input); - const NodeDef& node_def = relu6.operation.node()->def(); - + const NodeDef node_def = relu6.operation.node()->def(); + { + // Input is weights, should fail. + Reset(); + AddTestWeights("input", {1}, {1.0f}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Relu6 is only implemented for tensors, not weights, at my_relu6"); + } { // Clip tensor values and set quantization ranges, ok. Reset(); AddTestTensor("input", {1, 2, 3}); - RunConversion(node_def); + RunValidationAndConversion(node_def); TRT_TensorOrWeights output; TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output)); EXPECT_TRUE(output.is_tensor()); @@ -1387,17 +1958,10 @@ TEST_F(OpConverterTest, ConvertRelu6) { EXPECT_EQ(ranges[output.tensor()], 6.0f); std::vector output_data(6); - BuildAndRun("input", {-100, -1, 0, 3, 5, 9}, "my_relu6", &output_data); + BuildAndRun({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6", + &output_data); EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6)); } - { - // Input is weights, should fail. 
- Reset(); - AddTestWeights("input", {1, 2, 3}, {-100, -1, 0, 3, 5, 9}); - RunConversion( - node_def, error::UNIMPLEMENTED, - "Relu6 is only implemented for tensors, not weights, at my_relu6"); - } } } // namespace convert diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 780343d6623ceb0cf3d0f0ebfc30aa669c280f44..1e907e0d2a669b2bef5fc6ca0822c1e6049c7018 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -126,8 +126,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); OP_REQUIRES_OK(context, context->GetAttr("use_calibration", &use_calibration_)); - calibration_mode_ = (use_calibration_ && - (precision_mode_ == INT8MODE && calibration_data.size() == 0)); + calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE && + calibration_data.size() == 0); if (calibration_data.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); @@ -499,8 +499,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, // means calibration_mode_ is true and this path won't get executed. 
auto status = convert::ConvertGraphDefToEngine( segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, - &logger, allocator, calibrator_.get(), &engine, - use_calibration_, &convert_successfully); + &logger, allocator, calibrator_.get(), &engine, use_calibration_, + &convert_successfully); if (!status.ok()) { if (convert_successfully) { // This means it fail to build the engine even when the network is built diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 95f47c90148a1ab9773164fce1cb040235ecad5b..f0945087d92cbc08940699b760b1d06e5539bf7a 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -63,20 +63,20 @@ class TrtPrecisionMode(object): return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8] -def tensorrt_rewriter_config(rewriter_config=None, - max_batch_size=1, - max_workspace_size_bytes=2 << 20, - precision_mode=TrtPrecisionMode.FP32, - minimum_segment_size=3, - is_dynamic_op=False, - maximum_cached_engines=1, - cached_engine_batch_sizes=None, - use_calibration=True): +def get_tensorrt_rewriter_config(rewriter_config=None, + max_batch_size=1, + max_workspace_size_bytes=2 << 20, + precision_mode=TrtPrecisionMode.FP32, + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batch_sizes=None, + use_calibration=True): """Returns a RewriterConfig proto for TRT transformation. Args: - rewriter_config: a RewriterConfig proto to append the TensorRTOptimizer to. - If None, it will create one with default settings. + rewriter_config: a template RewriterConfig proto used to create a + TRT-enabled RewriterConfig. If None, it will use a default one. max_batch_size: max size for the input batch max_workspace_size_bytes: the maximum GPU temporary memory which the TRT engine can use at execution time. 
This corresponds to the 'workspaceSize' @@ -96,15 +96,15 @@ def tensorrt_rewriter_config(rewriter_config=None, use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. - use_calibration: this argument is ignored if precision_mode is not INT8. - if set to True, a calibration graph will be created to calibrate the - missing ranges. The calibration graph must be converted to an inference - graph using calib_graph_to_infer_graph() after running calibration. - if set to False, quantization nodes will be expected for every tensor in - the graph (exlcuding those which will be fused). If a range is missing, - an error will occur. Please note that accuracy may be negatively affected - if there is a mismatch between which tensors TRT quantizes and which - tensors were trained with fake quantization. + use_calibration: this argument is ignored if precision_mode is not INT8. If + set to True, a calibration graph will be created to calibrate the missing + ranges. The calibration graph must be converted to an inference graph + using calib_graph_to_infer_graph() after running calibration. if set to + False, quantization nodes will be expected for every tensor in the graph + (exlcuding those which will be fused). If a range is missing, an error + will occur. Please note that accuracy may be negatively affected if there + is a mismatch between which tensors TRT quantizes and which tensors were + trained with fake quantization. Returns: A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler. 
@@ -117,13 +117,16 @@ def tensorrt_rewriter_config(rewriter_config=None, rewriter_config, rewriter_config_pb2.RewriterConfig): raise TypeError("rewriter_config should be a RewriterConfig proto.") + rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig() if rewriter_config is None: - rewriter_config = rewriter_config_pb2.RewriterConfig() # Layout optimizer may add Const nodes followed by Reshape nodes, thus we # need to run constant folding again. - rewriter_config.optimizers.extend(["constfold", "layout", "constfold"]) - rewriter_config.meta_optimizer_iterations = ( + rewriter_config_with_trt.optimizers.extend( + ["constfold", "layout", "constfold"]) + rewriter_config_with_trt.meta_optimizer_iterations = ( rewriter_config_pb2.RewriterConfig.ONE) + else: + rewriter_config_with_trt.CopyFrom(rewriter_config) if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(): raise ValueError(("precision mode '{}' is not supported." @@ -131,7 +134,7 @@ def tensorrt_rewriter_config(rewriter_config=None, precision_mode, TrtPrecisionMode.supported_precision_modes)) - optimizer = rewriter_config.custom_optimizers.add() + optimizer = rewriter_config_with_trt.custom_optimizers.add() optimizer.name = "TensorRTOptimizer" optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size optimizer.parameter_map["max_batch_size"].i = max_batch_size @@ -149,7 +152,7 @@ def tensorrt_rewriter_config(rewriter_config=None, optimizer.parameter_map["cached_engine_batches"].list.i.extend( cached_engine_batch_sizes) optimizer.parameter_map["use_calibration"].b = use_calibration - return rewriter_config + return rewriter_config_with_trt def create_inference_graph(input_graph_def, @@ -161,7 +164,6 @@ def create_inference_graph(input_graph_def, is_dynamic_op=False, maximum_cached_engines=1, cached_engine_batch_sizes=None, - rewriter_config=None, use_calibration=True, input_saved_model_dir=None, input_saved_model_tags=None, @@ -194,17 +196,15 @@ def 
create_inference_graph(input_graph_def, use this list to determine the batch sizes of the cached engines, instead of making the decision on the fly. This is useful when we know the most common batch size(s) the application is going to generate. - rewriter_config: a RewriterConfig proto to append the TensorRTOptimizer to. - If None, it will create one with default settings. - use_calibration: this argument is ignored if precision_mode is not INT8. - if set to True, a calibration graph will be created to calibrate the - missing ranges. The calibration graph must be converted to an inference - graph using calib_graph_to_infer_graph() after running calibration. - if set to False, quantization nodes will be expected for every tensor in - the graph (exlcuding those which will be fused). If a range is missing, - an error will occur. Please note that accuracy may be negatively affected - if there is a mismatch between which tensors TRT quantizes and which - tensors were trained with fake quantization. + use_calibration: this argument is ignored if precision_mode is not INT8. If + set to True, a calibration graph will be created to calibrate the missing + ranges. The calibration graph must be converted to an inference graph + using calib_graph_to_infer_graph() after running calibration. if set to + False, quantization nodes will be expected for every tensor in the graph + (exlcuding those which will be fused). If a range is missing, an error + will occur. Please note that accuracy may be negatively affected if there + is a mismatch between which tensors TRT quantizes and which tensors were + trained with fake quantization. input_saved_model_dir: the directory to load the SavedModel which contains the input graph to transforms. Used only when input_graph_def is None. input_saved_model_tags: list of tags to load the SavedModel. @@ -212,8 +212,9 @@ def create_inference_graph(input_graph_def, returned GraphDef and save it to the specified directory. 
This option only works when the input graph is loaded from a SavedModel, i.e. when input_saved_model_dir is specified and input_graph_def is None. - session_config: the ConfigProto used to create a Session. If not specified, - a default ConfigProto will be used. + session_config: the ConfigProto used to create a Session. It's also used as + a template to create a TRT-enabled ConfigProto for conversion. If not + specified, a default ConfigProto will be used. Returns: A GraphDef transformed from input_graph_def (or the SavedModel graph def @@ -343,21 +344,30 @@ def create_inference_graph(input_graph_def, grappler_meta_graph_def.collection_def["train_op"].CopyFrom( output_collection) - # Create RewriterConfig. - rewriter_config = tensorrt_rewriter_config( + # Create TRT-enabled ConfigProto. + session_config_with_trt = config_pb2.ConfigProto() + session_config_with_trt.CopyFrom(session_config) + rewriter_config = None + if (session_config_with_trt.HasField("graph_options") and + session_config_with_trt.graph_options.HasField("rewrite_options")): + rewriter_config = session_config_with_trt.graph_options.rewrite_options + rewriter_config_with_trt = get_tensorrt_rewriter_config( rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode, minimum_segment_size, is_dynamic_op, maximum_cached_engines, cached_engine_batch_sizes, use_calibration) + session_config_with_trt.graph_options.rewrite_options.CopyFrom( + rewriter_config_with_trt) # Run Grappler. transformed_graph_def = tf_optimizer.OptimizeGraph( - rewriter_config, grappler_meta_graph_def, graph_id=b"tf_graph") + session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph") # Optionally write the transformed graphdef as SavedModel. if output_saved_model_dir is not None: saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir) with ops.Graph().as_default(): importer.import_graph_def(transformed_graph_def, name="") + # We don't use TRT here. 
with session.Session(config=session_config) as sess: saved_model_builder.add_meta_graph_and_variables( sess, diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py index 9f2eeac990dcacb547d336b68bc042016c3e6171..aa82f4207f5fa9c646cadbc4ca4fd7ab40c089ff 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py @@ -47,9 +47,9 @@ from tensorflow.python.tools import saved_model_utils class TrtConvertTest(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration python API.""" - def testTensorrtRewriterConfig(self): - """Test case for trt_convert.tensorrt_rewriter_config().""" - rewriter_cfg = trt_convert.tensorrt_rewriter_config( + def testGetTensorrtRewriterConfig(self): + """Test case for trt_convert.get_tensorrt_rewriter_config().""" + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( rewriter_config=None, max_batch_size=128, max_workspace_size_bytes=1234, diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index cbff661f99df0e6f6d1a2b0f8806849e7e5ca454..b325d76edfabce25f165a6b23c5f39bb6ac84247 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -56,8 +56,9 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): strides=[1, 2, 2, 1], padding="SAME", name="conv") - bias = constant_op.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + bias = constant_op.constant([4., 1.5, 2., 3., 5., 7.], + name="bias", + dtype=dtype) added = nn.bias_add(conv, bias, name="bias_add") relu = nn.relu(added, "relu") identity = array_ops.identity(relu, "identity") @@ -73,11 +74,12 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # TODO(aaroey): LayoutOptimizer adds 
additional nodes to the graph which - # breaks the connection check, fix it. - # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", - # "relu", "identity", "max_pool"] - return ["my_trt_op_0"] + return { + "my_trt_op_0": [ + "weights", "conv", "bias", "bias_add", "relu", "identity", + "max_pool" + ] + } class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -92,7 +94,7 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): g = ops.Graph() with g.as_default(): inp = array_ops.placeholder( - dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + dtype=dtype, shape=input_dims, name=input_name) with g.device("/GPU:0"): conv_filter = constant_op.constant( [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], @@ -105,10 +107,10 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): padding="SAME", name="conv") c1 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + np.random.randn(12, 12, 6), dtype=dtype, name="c1") p = math_ops.mul(conv, c1, name="mul") c2 = constant_op.constant( - np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + np.random.randn(12, 12, 6), dtype=dtype, name="c2") q = math_ops.div(conv, c2, name="div") edge = self.trt_incompatible_op(q, name="incompatible") @@ -129,22 +131,21 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. 
- # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", - # "add", "sub1"]; - # - my_trt_op_1 should have ["weights","conv", "div"] - return ["my_trt_op_0", "my_trt_op_1"] + return { + "my_trt_op_0": [ + "add", "add1", "c1", "div1", "mul", "mul1", "sub", "sub1" + ], + "my_trt_op_1": ["c2", "conv", "div", "weights"] + } - def ShouldRunTest(self, run_params): - # TODO(aaroey): LayoutOptimizer adds Transpose(Const, Const) to the graph - # which breaks the conversion. We should fix it as: - # - Detect the invalid NodeDef earlier before adding them to segment - # - Let it able to change the RewriterConfig when calling - # create_inference_graph(). - # It will be good to add debugging feature for Grappler to print the graph - # after running each optimizer. - return False + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + return super( + SimpleMultiEnginesTest, self + ).GetConversionParams(run_params)._replace( + # Disable layout optimizer, since it'll add Transpose(Const, Const) to + # the graph and breaks the conversion check. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): @@ -199,7 +200,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): # can cause overflow. 
return ((run_params.precision_mode != "FP16") and not (trt_test.IsQuantizationMode(run_params.precision_mode) and - not run_params.use_calibration)) + not run_params.use_calibration)) class PartiallyConvertedTestB(PartiallyConvertedTestA): diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 7545bb9df20f295a8fdbc82b573cdb3407f8c5e4..6546ef64778e0ee3638b3aea08c61a9b32e0dc7b 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -41,6 +41,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): input_name = "input" input_matrix_rows = 4 input_matrix_columns = 144 + # Note that tf.nn.bias_add supports up to 5 dimensions. input_dims = [input_matrix_rows, input_matrix_columns] output_name = "output" g = ops.Graph() @@ -74,18 +75,18 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): x5 = nn.bias_add(x5, b) x5 = gen_array_ops.reshape(x5, [4, -1]) - x6 = gen_array_ops.reshape(x, [4, 12, 12]) - b = self._ConstOp((12,)) + x6 = gen_array_ops.reshape(x, [4, 24, 6]) + b = self._ConstOp((6,)) x6 = nn.bias_add(x6, b, data_format="NHWC") x6 = gen_array_ops.reshape(x6, [4, -1]) - x7 = gen_array_ops.reshape(x, [4, 12, 3, 4]) - b = self._ConstOp((4,)) + x7 = gen_array_ops.reshape(x, [4, 12, 4, 3]) + b = self._ConstOp((3,)) x7 = nn.bias_add(x7, b, data_format="NHWC") x7 = gen_array_ops.reshape(x7, [4, -1]) - x8 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) - b = self._ConstOp((2,)) + x8 = gen_array_ops.reshape(x, [4, 4, 3, 2, 6]) + b = self._ConstOp((6,)) x8 = nn.bias_add(x8, b, data_format="NHWC") x8 = gen_array_ops.reshape(x8, [4, -1]) @@ -94,13 +95,13 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): x9 = nn.bias_add(x9, b, data_format="NCHW") x9 = gen_array_ops.reshape(x9, [4, -1]) - x10 = gen_array_ops.reshape(x, [4, 12, 3, 4]) - b = self._ConstOp((12,)) + x10 = gen_array_ops.reshape(x, [4, 
3, 4, 12]) + b = self._ConstOp((3,)) x10 = nn.bias_add(x10, b, data_format="NCHW") x10 = gen_array_ops.reshape(x10, [4, -1]) - x11 = gen_array_ops.reshape(x, [4, 12, 12]) - b = self._ConstOp((12,)) + x11 = gen_array_ops.reshape(x, [4, 6, 24]) + b = self._ConstOp((6,)) x11 = nn.bias_add(x11, b, data_format="NCHW") x11 = gen_array_ops.reshape(x11, [4, -1]) @@ -116,9 +117,14 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): def GetConversionParams(self, run_params): """Return a ConversionParams for test.""" - return super(BiasaddMatMulTest, - self).GetConversionParams(run_params)._replace( - max_batch_size=4, maximum_cached_engines=1) + conversion_params = super(BiasaddMatMulTest, + self).GetConversionParams(run_params) + return conversion_params._replace( + max_batch_size=4, + maximum_cached_engines=1, + # Disable layout optimizer, since it will convert BiasAdd with NHWC + # format to NCHW format under four dimentional input. + rewriter_config=trt_test.OptimizerDisabledRewriterConfig()) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py index 2586d936ef120c4548543ef82d2c7db3425d9c94..e7d6ec4ad395d38a06f97020f2f363009f2286c7 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py @@ -12,208 +12,279 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +"""Script to test TF-TRT INT8 conversion without calibration on Mnist model.""" -import numpy as np -import os +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -import tensorflow as tf -from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.keras.datasets import mnist -from tensorflow.python.framework import test_util -from tensorflow.python.platform import test -from tensorflow.python import estimator as tf_estimator +from tensorflow.contrib.tensorrt.python import trt_convert +# pylint: disable=unused-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +# pylint: enable=unused-import +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import data +from tensorflow.python import keras from tensorflow.python.estimator.estimator import Estimator +from tensorflow.python.estimator.model_fn import EstimatorSpec +from tensorflow.python.estimator.model_fn import ModeKeys from tensorflow.python.estimator.run_config import RunConfig -from tensorflow.python.estimator.model_fn import ModeKeys, EstimatorSpec +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras.datasets import mnist +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import 
test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import saver +from tensorflow.python.training.adam import AdamOptimizer +from tensorflow.python.training.checkpoint_management import latest_checkpoint +from tensorflow.python.training.training_util import get_global_step INPUT_NODE_NAME = 'input' OUTPUT_NODE_NAME = 'output' -def build_graph(x): - def quantize(x, r): - x = tf.fake_quant_with_min_max_args(x, -r, r) - return x - def dense_layer(x, num_inputs, num_outputs, quantization_range, name='dense'): - """Equivalent to tf.layers.dense but with a quantization range between - the MatMul and BiasAdd.""" - with tf.variable_scope(name) as scope: - kernel = tf.get_variable('kernel', shape=[num_inputs, num_outputs], - dtype=tf.float32, initializer=tf.keras.initializers.glorot_uniform()) - bias = tf.get_variable('bias', shape=[num_outputs,], - dtype=tf.float32, initializer=tf.keras.initializers.zeros()) - x = tf.matmul(x, kernel) - x = quantize(x, quantization_range) - x = tf.nn.bias_add(x, bias) +class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): + + def _BuildGraph(self, x): + + def _Quantize(x, r): + x = gen_array_ops.quantize_and_dequantize_v2(x, -r, r) + return x + + def _DenseLayer(x, num_inputs, num_outputs, quantization_range, name): + """Dense layer with quantized outputs. + + Args: + x: input to the dense layer + num_inputs: number of input columns of x + num_outputs: number of output columns + quantization_range: the min/max range for quantization + name: name of the variable scope + + Returns: + The output of the layer. 
+ """ + with variable_scope.variable_scope(name): + kernel = variable_scope.get_variable( + 'kernel', + shape=[num_inputs, num_outputs], + dtype=dtypes.float32, + initializer=keras.initializers.glorot_uniform()) + bias = variable_scope.get_variable( + 'bias', + shape=[num_outputs], + dtype=dtypes.float32, + initializer=keras.initializers.zeros()) + x = math_ops.matmul(x, kernel) + x = _Quantize(x, quantization_range) + x = nn.bias_add(x, bias) + x = _Quantize(x, quantization_range) + return x + + x = _Quantize(x, 1) + # Conv + Bias + Relu6 + x = layers.conv2d(x, filters=32, kernel_size=3, use_bias=True) + x = nn.relu6(x) + # Conv + Bias + Relu6 + x = layers.conv2d(x, filters=64, kernel_size=3, use_bias=True) + x = nn.relu6(x) + # Reduce + x = math_ops.reduce_mean(x, [1, 2]) + x = _Quantize(x, 6) + # FC1 + x = _DenseLayer(x, 64, 512, 6, name='dense') + x = nn.relu6(x) + # FC2 + x = _DenseLayer(x, 512, 10, 25, name='dense_1') + x = array_ops.identity(x, name=OUTPUT_NODE_NAME) return x - x = quantize(x, 1) - # Conv + Bias + Relu6 - x = tf.layers.conv2d(x, filters=32, kernel_size=3, use_bias=True) - x = tf.nn.relu6(x) - # Conv + Bias + Relu6 - x = tf.layers.conv2d(x, filters=64, kernel_size=3, use_bias=True) - x = tf.nn.relu6(x) - x = tf.reduce_mean(x, [1, 2]) - x = quantize(x, 6) - # FC1 - x = dense_layer(x, 64, 512, 6, name='dense') - x = quantize(x, 6) - x = tf.nn.relu6(x) - # FC2 - x = dense_layer(x, 512, 10, 25, name='dense_1') - x = quantize(x, 25) - x = tf.identity(x, name=OUTPUT_NODE_NAME) - return x - -def preprocess_fn(x, y): - x = tf.cast(x, tf.float32) - x = tf.expand_dims(x, axis=2) - x = 2.0 * (x / 255.0) - 1.0 - y = tf.cast(y, tf.int32) - return x, y - -def run(is_training, use_trt, batch_size, num_epochs, model_dir): - """Train or evaluate the model. - - Args: - is_training: Whether to train or evaluate the model. In training mode, - quantization will be simulated where the fake_quant_with_min_max_args - are placed. 
- use_trt: If true, use TRT INT8 mode for evaluation, which will perform real - quantization. Otherwise use native TensorFlow which will perform - simulated quantization. Ignored if is_training is True. - batch_size: Batch size. - num_epochs: How many epochs to train. Ignored if is_training is False. - model_dir: Where to save or load checkpoint. - """ - # Get dataset - train, test = mnist.load_data() - - def eval_input_fn(): - mnist_x, mnist_y = test - dataset = tf.data.Dataset.from_tensor_slices((mnist_x, mnist_y)) - dataset = dataset.apply(tf.data.experimental.map_and_batch( - map_func=preprocess_fn, - batch_size=batch_size, - num_parallel_calls=8)) - dataset = dataset.repeat(count=1) - iterator = dataset.make_one_shot_iterator() - features, labels = iterator.get_next() - return features, labels - - def train_input_fn(): - mnist_x, mnist_y = train - dataset = tf.data.Dataset.from_tensor_slices((mnist_x, mnist_y)) - dataset = dataset.shuffle(2*len(mnist_x)) - dataset = dataset.apply(tf.data.experimental.map_and_batch( - map_func=preprocess_fn, - batch_size=batch_size, - num_parallel_calls=8)) - dataset = dataset.repeat(count=num_epochs) - iterator = dataset.make_one_shot_iterator() - features, labels = iterator.get_next() - return features, labels - - def model_fn(features, labels, mode): - if is_training: - logits_out = build_graph(features) - else: - graph_def = get_graph_def(use_trt, batch_size, model_dir) - logits_out = tf.import_graph_def(graph_def, - input_map={INPUT_NODE_NAME: features}, - return_elements=[OUTPUT_NODE_NAME+':0'], - name='')[0] - loss = tf.losses.sparse_softmax_cross_entropy( - labels=labels, - logits=logits_out) - tf.summary.scalar('loss', loss) - classes_out = tf.argmax(logits_out, axis=1, name='classes_out') - accuracy = tf.metrics.accuracy( - labels=labels, - predictions=classes_out, - name='acc_op') - tf.summary.scalar('accuracy', accuracy[1]) - if mode == ModeKeys.EVAL: - return EstimatorSpec( - mode, - loss=loss, - 
eval_metric_ops={'accuracy': accuracy}) - elif mode == ModeKeys.TRAIN: - optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) - train_op = optimizer.minimize( - loss, - global_step=tf.train.get_global_step()) - return EstimatorSpec( - mode, - loss=loss, - train_op=train_op) - - tf_config = config_pb2.ConfigProto() - tf_config.gpu_options.allow_growth = True - estimator = Estimator( - model_fn=model_fn, - model_dir=None, - config=RunConfig(session_config=tf_config)) - if is_training: - estimator.train(train_input_fn) - results = estimator.evaluate(eval_input_fn) - print('accuracy:', results['accuracy']) - return results - -def get_graph_def(use_trt, batch_size, model_dir): - # Load graph and freeze - with tf.Graph().as_default() as graph: - with tf.Session() as sess: - x = tf.placeholder(shape=(None, 28, 28, 1), - dtype=tf.float32, - name=INPUT_NODE_NAME) - logits_out = build_graph(x) + def _GetGraphDef(self, use_trt, max_batch_size, model_dir): + """Get the frozen mnist GraphDef. + + Args: + use_trt: whether use TF-TRT to convert the graph. + max_batch_size: the max batch size to apply during TF-TRT conversion. + model_dir: the model directory to load the checkpoints. + + Returns: + The frozen mnist GraphDef. 
+ """ + graph = ops.Graph() + with self.session(graph=graph) as sess: + with graph.device('/GPU:0'): + x = array_ops.placeholder( + shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME) + self._BuildGraph(x) # Load weights - saver = tf.train.Saver() - checkpoint_file = tf.train.latest_checkpoint(model_dir) - saver.restore(sess, checkpoint_file) + mnist_saver = saver.Saver() + checkpoint_file = latest_checkpoint(model_dir) + mnist_saver.restore(sess, checkpoint_file) # Freeze - graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph_def, - output_node_names=[OUTPUT_NODE_NAME] + graph_def = graph_util.convert_variables_to_constants( + sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME]) + # Convert with TF-TRT + if use_trt: + logging.info('Number of nodes before TF-TRT conversion: %d', + len(graph_def.node)) + graph_def = trt_convert.create_inference_graph( + graph_def, + outputs=[OUTPUT_NODE_NAME], + max_batch_size=max_batch_size, + precision_mode='INT8', + max_workspace_size_bytes=4096 << 19, + minimum_segment_size=2, + use_calibration=False, ) - # Convert with TF-TRT - if use_trt: - print('nodes before:', len(graph_def.node)) - graph_def = create_inference_graph(graph_def, - outputs=[OUTPUT_NODE_NAME], - max_batch_size=batch_size, - precision_mode='int8', - max_workspace_size_bytes=4096 << 19, - minimum_segment_size=2, - use_calibration=False, - ) - print('tftrt total nodes:', len(graph_def.node)) - print('trt only nodes', - len([1 for n in graph_def.node if str(n.op)=='TRTEngineOp'])) - return graph_def + logging.info('Number of nodes after TF-TRT conversion: %d', + len(graph_def.node)) + num_engines = len( + [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp']) + self.assertEqual(1, num_engines) + return graph_def + def _Run(self, is_training, use_trt, batch_size, num_epochs, model_dir): + """Train or evaluate the model. 
-class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase): + Args: + is_training: whether to train or evaluate the model. In training mode, + quantization will be simulated where the quantize_and_dequantize_v2 are + placed. + use_trt: if true, use TRT INT8 mode for evaluation, which will perform + real quantization. Otherwise use native TensorFlow which will perform + simulated quantization. Ignored if is_training is True. + batch_size: batch size. + num_epochs: how many epochs to train. Ignored if is_training is False. + model_dir: where to save or load checkpoint. + + Returns: + The Estimator evaluation result. + """ + # Get dataset + train_data, test_data = mnist.load_data() + + def _PreprocessFn(x, y): + x = math_ops.cast(x, dtypes.float32) + x = array_ops.expand_dims(x, axis=2) + x = 2.0 * (x / 255.0) - 1.0 + y = math_ops.cast(y, dtypes.int32) + return x, y + + def _EvalInputFn(): + mnist_x, mnist_y = test_data + dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = dataset.apply( + data.experimental.map_and_batch( + map_func=_PreprocessFn, + batch_size=batch_size, + num_parallel_calls=8)) + dataset = dataset.repeat(count=1) + iterator = dataset.make_one_shot_iterator() + features, labels = iterator.get_next() + return features, labels + + def _TrainInputFn(): + mnist_x, mnist_y = train_data + dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y)) + dataset = dataset.shuffle(2 * len(mnist_x)) + dataset = dataset.apply( + data.experimental.map_and_batch( + map_func=_PreprocessFn, + batch_size=batch_size, + num_parallel_calls=8)) + dataset = dataset.repeat(count=num_epochs) + iterator = dataset.make_one_shot_iterator() + features, labels = iterator.get_next() + return features, labels + + def _ModelFn(features, labels, mode): + if is_training: + logits_out = self._BuildGraph(features) + else: + graph_def = self._GetGraphDef(use_trt, batch_size, model_dir) + logits_out = importer.import_graph_def( + graph_def, + 
input_map={INPUT_NODE_NAME: features}, + return_elements=[OUTPUT_NODE_NAME + ':0'], + name='')[0] + + loss = losses.sparse_softmax_cross_entropy( + labels=labels, logits=logits_out) + summary.scalar('loss', loss) + + classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out') + accuracy = metrics.accuracy( + labels=labels, predictions=classes_out, name='acc_op') + summary.scalar('accuracy', accuracy[1]) + if mode == ModeKeys.EVAL: + return EstimatorSpec( + mode, loss=loss, eval_metric_ops={'accuracy': accuracy}) + elif mode == ModeKeys.TRAIN: + optimizer = AdamOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss, global_step=get_global_step()) + return EstimatorSpec(mode, loss=loss, train_op=train_op) + + config_proto = config_pb2.ConfigProto() + config_proto.gpu_options.allow_growth = True + estimator = Estimator( + model_fn=_ModelFn, + model_dir=model_dir if is_training else None, + config=RunConfig(session_config=config_proto)) + + if is_training: + estimator.train(_TrainInputFn) + results = estimator.evaluate(_EvalInputFn) + logging.info('accuracy: %s', str(results['accuracy'])) + return results + + # To generate the checkpoint, set a different model_dir and call self._Run() + # by setting is_training=True and num_epochs=1000, e.g.: + # model_dir = '/tmp/quantization_mnist' + # self._Run( + # is_training=True, + # use_trt=False, + # batch_size=128, + # num_epochs=100, + # model_dir=model_dir) def testEval(self): - model_dir = test.test_src_dir_path( - 'contrib/tensorrt/test/quantization_mnist_test_data') - acc_tf = run(is_training=False, + if not trt_convert.is_tensorrt_enabled(): + return + model_dir = test.test_src_dir_path('contrib/tensorrt/test/testdata') + + accuracy_tf_native = self._Run( + is_training=False, use_trt=False, batch_size=128, num_epochs=None, model_dir=model_dir)['accuracy'] - acc_tftrt = run(is_training=False, + logging.info('accuracy_tf_native: %f', accuracy_tf_native) + self.assertAllClose(accuracy_tf_native, 
0.9662) + + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return + + accuracy_tf_trt = self._Run( + is_training=False, use_trt=True, batch_size=128, num_epochs=None, model_dir=model_dir)['accuracy'] - self.assertAllClose(acc_tf, 0.9717) - self.assertAllClose(acc_tftrt, 0.9744) + logging.info('accuracy_tf_trt: %f', accuracy_tf_trt) + self.assertAllClose(accuracy_tf_trt, 0.9677) + -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/tensorrt/test/quantization_test.py b/tensorflow/contrib/tensorrt/test/quantization_test.py index 83295ce2bd3a7392f5837b4e10bbf73c81d91255..28353273edec4a2b0fd4300f87b0b1a4dbe37652 100644 --- a/tensorflow/contrib/tensorrt/test/quantization_test.py +++ b/tensorflow/contrib/tensorrt/test/quantization_test.py @@ -20,88 +20,86 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_impl -from tensorflow.python.ops import nn_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -def build_graph(input_name, input_dims, output_name, - add_quantization_nodes=False, dtype=dtypes.float32): - def quantize(x, r): +def _GetParams(add_quantization_nodes, dtype=dtypes.float32): + input_name = "input" + input_dims = [8, 8] + output_name = "output" + + def _Quantize(x, r): if add_quantization_nodes: x = gen_array_ops.fake_quant_with_min_max_vars(x, -r, r) return x + g = ops.Graph() with g.as_default(): x = array_ops.placeholder( dtype=dtype, shape=[None] + 
input_dims[1:], name=input_name) - - x = quantize(x, 10.0) + x = _Quantize(x, 10.0) x = x + 5 - x = quantize(x, 15.0) + x = _Quantize(x, 15.0) x = x - 5 - x = quantize(x, 10.0) + x = _Quantize(x, 10.0) x = x * 0.1 - x = quantize(x, 1.0) - w = constant_op.constant(np.ones((10, 1)), dtype=dtypes.float32) + x = _Quantize(x, 1.0) + w = constant_op.constant(np.ones((8, 1)), dtype=dtypes.float32) x = math_ops.matmul(x, w) - x = quantize(x, 10.0) + x = _Quantize(x, 10.0) x = array_ops.identity(x, name=output_name) - return g + + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + output_names=[output_name], + expected_output_dims=[(8, 1)]) + class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing single segment with no quantization ranges.""" - input_name = "input" - input_dims = [128, 10] - output_name = "output" - g = build_graph(input_name, input_dims, output_name, - add_quantization_nodes=False) - return trt_test.TfTrtIntegrationTestParams( - gdef=g.as_graph_def(), - input_names=[input_name], - input_dims=[input_dims], - output_names=[output_name], - expected_output_dims=[(128, 1)]) + return _GetParams(add_quantization_nodes=False) def ShouldRunTest(self, run_params): - return (run_params.precision_mode == "INT8" and - not run_params.use_optimizer and - not run_params.dynamic_engine) + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return False + # Only test static engine mode, with or without calibration. + return (trt_test.IsQuantizationMode(run_params.precision_mode) and + not run_params.use_optimizer and not run_params.dynamic_engine) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" if run_params.use_calibration: + # In static engine mode with calibration, it should build a calibration + # engine. 
return ["my_trt_op_0"] + # In static engine mode without calibration, the engine building will fail + # since no quantization ranges are set, which results in no TRT nodes. return [] + class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing single segment with no quantization ranges.""" - input_name = "input" - input_dims = [128, 10] - output_name = "output" - g = build_graph(input_name, input_dims, output_name, - add_quantization_nodes=True) - return trt_test.TfTrtIntegrationTestParams( - gdef=g.as_graph_def(), - input_names=[input_name], - input_dims=[input_dims], - output_names=[output_name], - expected_output_dims=[(128, 1)]) + return _GetParams(add_quantization_nodes=True) def ShouldRunTest(self, run_params): - return (run_params.precision_mode == "INT8" and + if trt_convert.get_linked_tensorrt_version()[0] < 5: + return False + # Test static/dynamic engine with/without calibration. + return (trt_test.IsQuantizationMode(run_params.precision_mode) and not run_params.use_optimizer) def ExpectedEnginesToBuild(self, run_params): @@ -116,30 +114,23 @@ class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase): """The relative tolerance to compare floating point results.""" return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + class NonQuantizedPrecisionsWithRangesTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Create a graph containing single segment with no quantization ranges.""" - input_name = "input" - input_dims = [128, 10] - output_name = "output" - g = build_graph(input_name, input_dims, output_name, - add_quantization_nodes=True) - return trt_test.TfTrtIntegrationTestParams( - gdef=g.as_graph_def(), - input_names=[input_name], - input_dims=[input_dims], - output_names=[output_name], - expected_output_dims=[(128, 1)]) + return _GetParams(add_quantization_nodes=True) def ShouldRunTest(self, run_params): - return (run_params.precision_mode == "FP32" or - 
run_params.precision_mode == "FP16") + # Only test FP32/FP16 mode. + return not trt_test.IsQuantizationMode(run_params.precision_mode) def ExpectedEnginesToBuild(self, run_params): """Return the expected engines to build.""" + # The fake quant ops are not supported in FP32/FP16 mode, and will split the + # graph into three TRT segments. return ["my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3"] - + def ExpectedAbsoluteTolerance(self, run_params): """The absolute tolerance to compare floating point results.""" return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 @@ -148,5 +139,6 @@ class NonQuantizedPrecisionsWithRangesTest(trt_test.TfTrtIntegrationTestBase): """The relative tolerance to compare floating point results.""" return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01 + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/testdata/checkpoint b/tensorflow/contrib/tensorrt/test/testdata/checkpoint new file mode 100644 index 0000000000000000000000000000000000000000..a603e1aec91adab04fd9801ba05a2ee9adfbb6e8 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/testdata/checkpoint @@ -0,0 +1,3 @@ +model_checkpoint_path: "model.ckpt-46900" +all_model_checkpoint_paths: "model.ckpt-0" +all_model_checkpoint_paths: "model.ckpt-46900" diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..88a998f184b275121e1e76eb51d2310da149f10a Binary files /dev/null and b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 differ diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index new file mode 100644 index 0000000000000000000000000000000000000000..537976571337508ab1798d33646c51d62a146ecc Binary files /dev/null and 
b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index differ diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 8804f2bc8f73e085982d10b8dba2a54d30eed608..80eb8552fd01531be76c228c10830c2fa33a2dec 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer @@ -66,6 +67,34 @@ class GraphState(object): INFERENCE = 2 +def OptimizerDisabledRewriterConfig(): + """Returns a RewriterConfig with all default Grappler optimizers disabled.""" + rewriter_config = rewriter_config_pb2.RewriterConfig() + + # Turn off all default Grappler optimizers. + off = rewriter_config_pb2.RewriterConfig.OFF + rewriter_config.layout_optimizer = off + rewriter_config.constant_folding = off + rewriter_config.shape_optimization = off + rewriter_config.remapping = off + rewriter_config.arithmetic_optimization = off + rewriter_config.dependency_optimization = off + rewriter_config.loop_optimization = off + rewriter_config.function_optimization = off + rewriter_config.debug_stripper = off + rewriter_config.disable_model_pruning = True + rewriter_config.scoped_allocator_optimization = off + rewriter_config.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.NO_MEM_OPT) + rewriter_config.pin_to_host_optimization = off + rewriter_config.auto_parallel.enable = False + + # Run only once for each enabled optimizer. 
+ rewriter_config.meta_optimizer_iterations = ( + rewriter_config_pb2.RewriterConfig.ONE) + return rewriter_config + + class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Class to test Tensorflow-TensorRT integration.""" @@ -203,11 +232,16 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + def _GetGPUOptions(self): + gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True + return gpu_options + def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: conversion_params = self.GetConversionParams(run_params) - rewriter_cfg = trt_convert.tensorrt_rewriter_config( + rewriter_cfg = trt_convert.get_tensorrt_rewriter_config( conversion_params.rewriter_config, conversion_params.max_batch_size, conversion_params.max_workspace_size_bytes, conversion_params.precision_mode, @@ -221,13 +255,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): else: graph_options = config_pb2.GraphOptions() - gpu_options = config_pb2.GPUOptions() - gpu_options.allow_growth = True - if trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options.per_process_gpu_memory_fraction = 0.50 - config = config_pb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) + gpu_options=self._GetGPUOptions(), graph_options=graph_options) return config def _ExpectTestValue(self, engine_name, method, expected_value): @@ -297,6 +326,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): params = self._GetParamsCached() conversion_params = self.GetConversionParams(run_params) logging.info(conversion_params) + + config_for_trt = config_pb2.ConfigProto(gpu_options=self._GetGPUOptions()) + if conversion_params.rewriter_config is not None: + 
config_for_trt.graph_options.rewrite_options.CopyFrom( + conversion_params.rewriter_config) return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=params.input_names + params.output_names, @@ -307,8 +341,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): is_dynamic_op=conversion_params.is_dynamic_op, maximum_cached_engines=conversion_params.maximum_cached_engines, cached_engine_batch_sizes=conversion_params.cached_engine_batch_sizes, - rewriter_config=conversion_params.rewriter_config, - use_calibration=conversion_params.use_calibration) + use_calibration=conversion_params.use_calibration, + session_config=config_for_trt) def _WriteGraph(self, run_params, gdef, graph_state): if graph_state == GraphState.ORIGINAL: @@ -408,13 +442,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self.assertEqual(run_params.dynamic_engine, is_dynamic_engine, node.name) self.assertEqual(node.attr["use_calibration"].b, - run_params.use_calibration, - node.name) + run_params.use_calibration, node.name) has_calibration_data = len(node.attr["calibration_data"].s) if (IsQuantizationMode(run_params.precision_mode) and - run_params.use_calibration and - graph_state == GraphState.INFERENCE): + run_params.use_calibration and graph_state == GraphState.INFERENCE): self.assertTrue(has_calibration_data, node.name) else: self.assertFalse(has_calibration_data, node.name) @@ -449,6 +481,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): # types. scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0 dims = params.input_dims[i] + # TODO(laigd): add debug options. E.g. 
we can set the input data to be + # continuous natural numbers: + # seq = np.arange(np.prod(dims)) + # seq.resize(dims) + # input_data.append(scale * seq.astype(dtype)) input_data.append((scale * np.random.random_sample(dims)).astype(dtype)) self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL) @@ -541,7 +578,7 @@ def _AddTests(test_class): # graphdef using custom python wrapper class, which is not currently # supported yet. continue - if not dynamic_engine and use_calibration: + if use_calibration and not dynamic_engine: # Static engine with use_calibration=False will be static, so we want to # test that. If use_calibration=True, only dynamic op is supported. # TODO(aaroey): construction of static calibration engine is not @@ -553,8 +590,10 @@ def _AddTests(test_class): continue conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" - engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") - test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine" + calibration_type = "UseCalibration" if use_calibration else "NoCalibration" + test_name = "%s_%s_%s_%s" % (conversion, engine_type, precision_mode, + calibration_type) run_params = RunParams( use_optimizer=use_optimizer, precision_mode=precision_mode, diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index c230919168b937b26c68e141e15f0762ad70f3e6..ae7db35b47b326272dd2c7bc76e18047cec59865 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -106,6 +106,7 @@ py_test( ], srcs_version = "PY2AND3", tags = [ + "no_mac", "no_pip_gpu", # b/63391119 "nomsan", # Takes too long to run. 
"notsan", # b/67865658 diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index af68aa03cf6583dc474eda6cda2e648fa1c3d08d..146ed9f27134e3e2a6c74627b6b78e53d65155f0 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -32,7 +32,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.estimator.export import export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index ffd838be40ed6267109fe36d95a681496fb2f964..7d780559f976516823611f3fe0ded056e4be088c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -30,7 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils from tensorflow.python.client import session from tensorflow.python.estimator import estimator_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 
90c7d8ac1a9c69216ece74af458cd750667f51ee..8f692d94da45bfaed6c72cf75d525346865aea34 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -38,7 +38,7 @@ from tensorflow.core.example import example_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index 43c5267e632e464d43ffcbcf6c551ff83d3c5767..aab330643862c1ccf073d2a0e34e1c475b1ec15f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -802,7 +802,7 @@ class InputStatisticsFromMiniBatch(object): array_ops.shape(times)[1] - 1, self._dtype)) # Co-locate updates with their variables to minimize race conditions when # updating statistics. - with ops.colocate_with(auxiliary_variables.max_time_seen): + with ops.device(auxiliary_variables.max_time_seen.device): # There is a race condition if this value is being updated from multiple # workers. However, it should eventually reach the correct value if the # last chunk is presented enough times. 
@@ -810,16 +810,16 @@ class InputStatisticsFromMiniBatch(object): auxiliary_variables.max_time_seen, gen_math_ops.maximum(auxiliary_variables.max_time_seen, math_ops.reduce_max(times))) - with ops.colocate_with(auxiliary_variables.chunk_count): + with ops.device(auxiliary_variables.chunk_count.device): chunk_count_assign = state_ops.assign_add(auxiliary_variables.chunk_count, array_ops.shape( times, out_type=dtypes.int64)[0]) - with ops.colocate_with(auxiliary_variables.inter_observation_duration_sum): + with ops.device(auxiliary_variables.inter_observation_duration_sum.device): inter_observation_duration_assign = state_ops.assign_add( auxiliary_variables.inter_observation_duration_sum, math_ops.reduce_sum(batch_inter_observation_duration)) - with ops.colocate_with(auxiliary_variables.example_count): + with ops.device(auxiliary_variables.example_count.device): example_count_assign = state_ops.assign_add( auxiliary_variables.example_count, array_ops.size(times, out_type=dtypes.int64)) @@ -829,11 +829,11 @@ class InputStatisticsFromMiniBatch(object): # the series are then members of fewer chunks. For series which are much # longer than the chunk size (the usual/expected case), this effect becomes # irrelevant. 
- with ops.colocate_with(auxiliary_variables.overall_feature_sum): + with ops.device(auxiliary_variables.overall_feature_sum.device): overall_feature_sum_assign = state_ops.assign_add( auxiliary_variables.overall_feature_sum, math_ops.reduce_sum(values, axis=[0, 1])) - with ops.colocate_with(auxiliary_variables.overall_feature_sum_of_squares): + with ops.device(auxiliary_variables.overall_feature_sum_of_squares.device): overall_feature_sum_of_squares_assign = state_ops.assign_add( auxiliary_variables.overall_feature_sum_of_squares, math_ops.reduce_sum(values**2, axis=[0, 1])) @@ -869,7 +869,7 @@ class InputStatisticsFromMiniBatch(object): state_ops.assign(statistics.series_start_moments.mean, mean), state_ops.assign(statistics.series_start_moments.variance, variance)) - with ops.colocate_with(statistics.start_time): + with ops.device(statistics.start_time.device): series_start_update = control_flow_ops.cond( # Update moments whenever we even match the lowest time seen so far, # to ensure that series start statistics are eventually updated to diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index edd97b2a4c131dbce0a5111dbac7d40eddea2bae..a8cd4287e0003de300b7114cf3f88d21d3239e6e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -27,7 +27,7 @@ from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape diff --git 
a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 3c07a74ed8af9e3ab70408f9b43cb62b6bd4c7f2..125750e7639ad40c481472a93353e6fb7055be96 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -40,7 +40,10 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["state_space_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_mac", + "no_windows", # TODO: needs investigation on Windows + ], deps = [ ":state_space_model", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 67327d32000caea5db75f4d83e5743e8bde70a92..a0a9cb3f31a945a00eb3f6a5fd1402aab9a2df5f 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -246,6 +246,7 @@ py_library( "python/tpu/bfloat16.py", "python/tpu/device_assignment.py", "python/tpu/session_support.py", + "python/tpu/tensor_tracer.py", "python/tpu/topology.py", "python/tpu/tpu.py", "python/tpu/tpu_feed.py", diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 38d1c3049ef7185f2f9f448361029d066678cdae..541fbf33a302a4d850422885fdbbc438bd6b9b7b 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -94,13 +94,6 @@ tf_proto_library( visibility = ["//visibility:public"], ) -tf_proto_library( - name = "tf_op_stats_proto", - srcs = ["tf_op_stats.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - tf_proto_library( name = "tpu_profiler_analysis_proto", srcs = ["tpu_profiler_analysis.proto"], diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto deleted file mode 100644 index 
1e66801efd4b2a997ed85289b9b1690bb5d07737..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ /dev/null @@ -1,261 +0,0 @@ -// This proto describes the format of tensorflow operation level stats for -// profiling (in tensorboard) purpose. - -syntax = "proto2"; - -package tensorflow.tpu; - -// Result proto for OpMetrics. -message OpMetricsResult { - // True if this OP is executed on the device; False if it is executed on the - // host. - optional bool on_device = 1; - reserved 2; // was uint32 id. - // Name of this OP. - optional string name = 3; - // Rank of this OP. - optional uint64 rank = 4; - // The starting time in cycles of the last instance of this OP executed. - optional double last_starttime_in_cycles = 5; - // The ending time in cycles of the last instance of this OP executed. - optional double last_endtime_in_cycles = 6; - // If this OP (say A), is an immediate child of another OP (say B), this field - // stores the sum of duration in microseconds of A inside B. If A appears more - // than once in B, the duration of all A's appearances will be added together. - // This sum will be reset after the self-time of B is calculated so that it - // can be reused for a new parent OP. - optional double sum_of_duration_in_us_as_children = 7; - // Number of instances that this OP occurred. - optional uint64 occurrences = 8; - // Total time in microseconds spent in this OP (accumulated - // over all of its occurrences). - optional double total_time_in_us = 9; - // Total self time in microseconds spent in this OP - // (accumulated over all of its occurrences). - optional double total_self_time_in_us = 10; - // The total self time as a fraction of sum of all OP's - // total self time on the host. - optional double host_total_self_time_as_fraction_of_all_op_time = 11; - // Cumulative total self time in fraction on the host. 
- optional double host_cumulative_total_self_time_as_fraction_of_all_op_time = - 12; - // The total self time as a fraction of sum of all OP's - // total self time on the device. - optional double device_total_self_time_as_fraction_of_all_op_time = 13; - // Cumulative total self time in fraction on the device. - optional double device_cumulative_total_self_time_as_fraction_of_all_op_time = - 14; - // Total number of FLOPs incurred by this OP. - optional double total_flops = 15; - // Total number of bytes accessed by this OP. - optional double total_bytes_accessed = 16; - // Total time in microseconds that special hw unit 1 is occupied by this OP. - optional double unit1_occupancy_in_us = 17; - // Total time in microseconds that special hw unit 2 is occupied by this OP. - optional double unit2_occupancy_in_us = 18; - // Total memory stall time in microseconds. - optional double total_memory_stall_in_us = 19; -} - -// Result proto for OpMetricsDb. -message OpMetricsDbResult { - // A bunch of OpMetricsResults. - repeated OpMetricsResult metrics_db = 1; - // The total host infeed-enqueue duration in picoseconds. - optional uint64 total_host_infeed_enq_duration_ps = 2; - // The total of the difference between the start times of two - // consecutive infeed-enqueues (per host) in picoseconds. - optional uint64 total_host_infeed_enq_start_timestamp_ps_diff = 3; - // The total device time in microseconds. - optional double total_device_time_in_us = 4; - // The total host time in microseconds. - optional double total_host_time_in_us = 5; -} - -// Result proto for StepInfo. -message StepInfoResult { - // The (micro) step number. - optional uint32 step_num = 1; - // The step duration in picoseconds. - optional uint64 duration_ps = 2; - // The infeed duration in picoseconds. - optional uint64 infeed_duration_ps = 3; - // The outfeed duration in picoseconds. - optional uint64 host_outfeed_ps = 8; - // The start time of this step in picoseconds. 
- optional uint64 begin_ps = 4; - // The waiting time within this step in picoseconds. - optional uint64 wait_duration_ps = 5; - // The unit b outfeed duration in picoseconds. - optional uint64 unit_b_outfeed_ps = 9; - // The time spent on cross-replica-sum in picoseconds. - optional uint64 crs_duration_ps = 6; - // Percentage of unit b time spent on infeed. - optional double unit_b_infeed_percent = 7; -} - -// Result proto for a sequence of steps. -message StepSequenceResult { - // A sequence of StepInfoResults. - repeated StepInfoResult step_sequence = 1; -} - -// Result proto for a StepDatabase. -message StepDatabaseResult { - // A map from core_id to StepSequenceResult. - map step_sequence_per_core = 1; -} - -// Result proto for looping-related metrics. -message LoopingResult { - // The total iteration time in nanoseconds. - optional double iteration_time_ns = 1; - // The total number of iterations. - optional int32 num_iterations = 2; - // The total computation time in nanoseconds. - optional double computation_time_ns = 3; - // The total number of computations. - optional int32 num_computations = 4; -} - -// Result proto for HloExtraInfo. -message HloExtraInfoResult { - // Category of the HLO op given by the compiler. - optional string category = 1; - // The long name of the HLO that includes the dimensions. - optional string long_name = 2; - // The per-TPU-core batch size inferred from this HLO. - optional int64 per_core_batch_size = 3; -} - -// Result proto for HloExtraInfoMap. -message HloExtraInfoMapResult { - // A map from HLO name to HloExtraInfo. - map hlo_extrainfo_map = 1; -} - -// Result proto for host-independent job information. -message HostIndependentJobInfoResult { - // The change-list number of this build. - optional int64 change_list = 1; - // The time of this build. - optional int64 build_time = 2; - // The target of this build. - optional string build_target = 3; -} - -// Result proto for host-dependent job information. 
-message HostDependentJobInfoResult { - // This ID of the host where the job was run on. - optional string host_id = 1; - // The command line used to run the job. - optional string command_line = 2; - // The start time of the job on this host. - optional int64 start_time = 3; -} - -// Result proto for RunEnvironment (the run environment of a profiling session). -message RunEnvironmentResult { - // Number of hosts used. - optional int32 host_count = 1; - // The type of TPU used. - optional string tpu_type = 2; - // The number of TPU cores used. - optional int32 tpu_core_count = 3; - // The per-TPU-core batch size. - optional int32 per_core_batch_size = 4; - // Host-independent job information. - optional HostIndependentJobInfoResult host_independent_job_info = 5; - // Host-dependent job information. - repeated HostDependentJobInfoResult host_dependent_job_info = 6; - // The number of replicas, corresponds to input parallelism. - // If there is no model parallelism, replica_count = tpu_core_count - optional int32 replica_count = 7; - // The number of cores used for a single replica, e.g. model parallelism. - // If there is no model parallelism, then num_cores_per_replica = 1 - optional int32 num_cores_per_replica = 8; -} - -// The types of host operations that are tracked. -enum HostOp { - // Invalid host op. - kINVALIDHostOp = 0; - // Each of host op type has two parts: - // (1) the stage where the op happens and (2) the op name. - // stage = Input Data Producer, op = Get Next Batch. - kInputDataProducerGetNextBatch = 1; - // stage = Input Data Producer, op = Session Run. - kInputDataProducerSessionRun = 2; - // stage = Input Data Producer, op = Forward Batch. - kInputDataProducerForwardBatch = 3; - // stage = Infeed Thread, op = Get Next Batch. - kInfeedThreadGetNextBatch = 4; - // stage = Infeed Thread, op = Session Run. - kInfeedThreadSessionRun = 5; - // stage = Infeed Thread, op = Forward Batch. 
- kInfeedThreadForwardBatch = 6; - // stage = Outfeed Thread, op = Get Next Batch. - kOutfeedThreadGetNextBatch = 7; - // stage = Outfeed Thread, op = Session Run. - kOutfeedThreadSessionRun = 8; - // stage = Outfeed Thread, op = Forward Batch. - kOutfeedThreadForwardBatch = 9; -} - -// Result proto for the host ops per TPU step. -message HostOpsPerTpuStep { - // Whether the data in this message is valid. - optional bool valid = 1 [default = false]; - // The current TPU step number. - optional uint32 tpu_step_num = 2; - // The beginning time of the current TPU step on the device in picoseconds. - optional uint64 tpu_step_begin_ps = 3; - // The ending time of the current TPU step on the device in picoseconds. - optional uint64 tpu_step_end_ps = 4; - // For each possible host operation, maps to the difference between the TPU - // step number that the host op targets and the current TPU step number. - // The key is HostOp, value is the step difference. - map step_diffs = 5; -} - -message HostOpsDetailsPerCore { - // Map from core id to HostOpsPerTpuStep. - map core_map = 1; -} - -message HostOpsDetailsPerHost { - // Map from hostname to a map from core id to HostOpsPerTpuStep. - map host_map = 1; -} - -// Result proto for the host ops for all TPU steps. -message HostOpsResult { - reserved 1; // (was repeated HostOpsPerTpuStep host_op_sequence) - // A sequence of records with one for each TPU step. Each record - // is a map from hostname to a map from core id to HostOpsPerTpuStep. - repeated HostOpsDetailsPerHost hostops_details = 2; -} - -// Result proto for TfStatsHelper. -message TfOpStats { - // The result for the TF-metric database. - optional OpMetricsDbResult tf_metrics_db = 1; - // The result for the HLO-metric database. - optional OpMetricsDbResult hlo_metrics_db = 2; - // The result for the step database. - optional StepDatabaseResult step_db = 3; - // The result for the looping-related metrics. 
- optional LoopingResult looping = 4; - // The result for the HloExtraInfoMap. - optional HloExtraInfoMapResult hlo_extrainfo_map = 5; - // Overall matrix unit utilization in percentage. - optional double matrix_unit_utilization_percent = 6; - // The run environment of this profiling session. - optional RunEnvironmentResult run_environment = 7; - // The result for the host operations. - optional HostOpsResult host_ops = 8; - // A map from core ID to name. - map core_id_to_name_map = 9; - // The result for hw unit b stats. - optional bytes unit_b_stats = 10; -} diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index c2e3be03db0e4cca1a664f9e79aa9107384de312..aae1ab1d37a166303883e3a07a7a01efe2feab51 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -154,6 +154,14 @@ message OptimizationParameters { // updates; not present means no limits are applied. ClippingLimits gradient_clipping_limits = 7; + // Amount of weight decay to apply; see weight_decay_optimizers.py for + // details. Almost all optimizers are supported with this option (MDL Adagrad + // Light does not work, and SGD does not behave as expected if it is enabled). + // Although there is no check, users who want weight decay will probably also + // want to enable gradient accumulation as well so that the decay will happen + // once per minibatch. + float weight_decay_factor = 16; + // Whether to use gradient accumulation (do two passes over the input // gradients: one to accumulate them into a temporary array and another to // apply them using the actual optimization algorithm). 
This feature is diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index c32bd5997c1493594d253650d42ae2215b2862a2..1b09ce173a64ba3f93ec019c8fd65dc4710f0fcf 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -80,6 +80,8 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): self._summary_writer = None self._global_step_tensor = None + self._last_checkpoint_step = None + def _set_steps_per_run(self, steps_per_run): self._steps_per_run = steps_per_run @@ -137,8 +139,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): last_step = session.run(self._global_step_tensor) - # Save the last checkpoint synchronously if needed. - if last_step != self._timer.last_triggered_step(): + if self._last_checkpoint_step != last_step: self._save(session, last_step, asynchronous=False) for l in self._listeners: @@ -164,15 +165,17 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) + + for l in self._listeners: + l.after_save(session, step) + end_time = time.time() logging.info("Checkpoint actual writing time: (%.3f sec)", end_time - start_time) logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - for l in self._listeners: - l.before_save(session, step) - if not asynchronous: + self._last_checkpoint_step = step _save_fn() return @@ -182,6 +185,7 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): logging.info("Saver thread still in progress, skipping checkpoint.") return + self._last_checkpoint_step = step self._save_thread = threading.Thread(target=_save_fn) self._save_thread.start() diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 
abf9dc810fda97e5617f3be7fb85b6e782e3ca86..73753cd9181403d97b18f117a17e3e75e1f3b974 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -978,7 +978,7 @@ class TPUFunction(object): # When running on more than one core, concatenate outputs at the end # of processing. In backprop stage, the gradients will be - # calculdated according to the local inputs as gradient of + # calculated according to the local inputs as gradient of # cross-replica-concat being zero for any outputs other than those # from mlocal core so the loss calculation is identical. num_towers = self.model._tpu_assignment.num_towers @@ -1005,14 +1005,17 @@ class TPUFunction(object): for tensor in tpu_targets ] - if is_training or is_test: + if is_training or is_test: + with variable_scope.variable_scope( + 'metrics', reuse=variable_scope.AUTO_REUSE): self._cloned_model.compile( optimizer=_replicated_optimizer(self._cloned_optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, - metrics=metrics_module.clone_metrics(self.model.metrics), + metrics=metrics_module.clone_metrics( + self.model._compile_metrics), weighted_metrics=metrics_module.clone_metrics( - self.model.weighted_metrics), + self.model._compile_weighted_metrics), target_tensors=tpu_targets, ) @@ -1024,29 +1027,29 @@ class TPUFunction(object): # the Momentum optimizer) when _make_train_function is invoked. 
with keras_tpu_variables.replicated_variable_for_optimizer( self._tpu_assignment.num_towers): - self._cloned_model._make_train_function() + self._cloned_model._make_fit_function() else: - self._cloned_model._make_train_function() + self._cloned_model._make_fit_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model.train_function.outputs + for tensor in self._cloned_model._fit_function.outputs ] return [ - self._cloned_model.train_function.updates_op, + self._cloned_model._fit_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - self._cloned_model.train_function.outputs, + self._cloned_model._fit_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - self._cloned_model._make_test_function() + self._cloned_model._make_eval_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in self._cloned_model.test_function.outputs + for tensor in self._cloned_model._eval_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - self._cloned_model.test_function.outputs, + self._cloned_model._eval_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: @@ -1182,13 +1185,9 @@ class TPUFunction(object): # pipelined loop. return None, None - if (self.model.uses_learning_phase and - not isinstance(K.learning_phase(), int)): + if isinstance(inputs[-1], int): # Remove the learning_phase flag at the end. We currently hard code the # learning_phase in TPUFunction. - assert isinstance(inputs[-1], int), ( - 'Expect the final element be learning_phase flag. 
Got {}'.format( - inputs[-1])) inputs = inputs[:-1] if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or @@ -1376,6 +1375,9 @@ class KerasTPUModel(models.Model): self.predict_function = None self.test_function = None self.train_function = None + self._fit_function = None + self._eval_function = None + self._stateful_metric_functions = [] cluster_resolver = strategy._tpu_cluster_resolver self._tpu_name_or_address = cluster_resolver.get_master() @@ -1390,10 +1392,10 @@ class KerasTPUModel(models.Model): self.compile( self._cpu_model.optimizer, self._cpu_model.loss, - self._cpu_model.metrics, + self._cpu_model._compile_metrics, self._cpu_model.loss_weights, self._cpu_model.sample_weight_mode, - self._cpu_model.weighted_metrics, + self._cpu_model._compile_weighted_metrics, self._cpu_model.target_tensors, ) @@ -1647,7 +1649,7 @@ class KerasTPUModel(models.Model): self._make_train_function() sample_weights = sample_weights or [] val_sample_weights = val_sample_weights or [] - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + if not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [1] else: ins = inputs + targets + sample_weights @@ -1697,7 +1699,7 @@ class KerasTPUModel(models.Model): callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): # Reset stateful metrics - for m in self.stateful_metric_functions: + for m in self.metrics: m.reset_states() # Update callbacks callbacks.on_epoch_begin(epoch) @@ -1994,10 +1996,21 @@ class KerasTPUModel(models.Model): def optimizer(self, optimizer): self._optimizer = optimizer + @property + def metrics(self): + if self._tpu_model: + return self._tpu_model.metrics + return self._stateful_metric_functions + + @metrics.setter + def metrics(self, metrics): + self._stateful_metric_functions = metrics + def _make_train_function(self): if not self.train_function: self.train_function = TPUFunction( - self, model_fn_lib.ModeKeys.TRAIN, + self, + 
model_fn_lib.ModeKeys.TRAIN, tpu_assignment=self._tpu_assignment) return self.train_function @@ -2008,6 +2021,21 @@ class KerasTPUModel(models.Model): self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) return self.test_function + def _make_fit_function(self): + if not self._fit_function: + self._fit_function = TPUFunction( + self, + model_fn_lib.ModeKeys.TRAIN, + tpu_assignment=self._tpu_assignment) + + return self._fit_function + + def _make_eval_function(self): + if not self._eval_function: + self._eval_function = TPUFunction( + self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) + return self._eval_function + def _make_predict_function(self): if not self.predict_function: self.predict_function = TPUFunction( @@ -2201,10 +2229,10 @@ def tpu_model(model, strategy=None): cpu_model.compile( _clone_optimizer(model.optimizer, optimizer_config), model.loss, - metrics_module.clone_metrics(model.metrics), + metrics_module.clone_metrics(model._compile_metrics), model.loss_weights, model.sample_weight_mode, - metrics_module.clone_metrics(model.weighted_metrics), + metrics_module.clone_metrics(model._compile_weighted_metrics), ) if model_weights: diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..70baea203cc6174bebc7d90646045efae5f2391d --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -0,0 +1,553 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================== +"""A utility to trace tensor values on TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path +import re + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging + +_TRACER_LOG_PREFIX = ' [>>>TT>>>]' +_DEVICE_TYPE_TPU = 'tpu' +_DEVICE_TYPE_CPU = 'cpu' +_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' +_TRACE_MODE_NAN_INF = 'nan-inf' +_TRACE_MODE_PART_TENSOR = 'part-tensor' +_TRACE_MODE_PART_TENSOR_SIZE = 3 +_TRACE_MODE_FULL_TENSOR = 'full-tensor' +_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' +_RECORD_FILTERED_OUT = 'not-traced-filtered-out' +_RECORD_SCALAR = 'not-traced-scalar' +_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_RECORD_GET_TRACED = 'get-traced' +_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' +_MARKER_SECTION_END = '!!!!!!! 
section-end:' +_SECTION_NAME_CONFIG = 'configuration' +_SECTION_NAME_REASON = 'reason' +_SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_GRAPH = 'graph' +_FIELD_NAME_VERSION = 'version:' +_FIELD_NAME_DEVICE = 'device:' +_FIELD_NAME_TRACE_MODE = 'trace-mode:' +_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' +_FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' +_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' +_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") +_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') +_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') +_FLAG_NAME_ENABLE = 'enable' +_FLAG_NAME_TRACE_MODE = 'trace_mode' +_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' +_FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' +_FLAG_NAME_OP_RANGE = 'op_range' +_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') +_OUTPUT_STREAM_ESCAPE = 'file://' +_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' + + +class TensorTracer(object): + """A software construct for tracing tensor values in a TF graph on TPU. + + This utility is disabled by default. It can be enabled by setting + the TENSOR_TRACER_FLAGS env variable as: + export TENSOR_TRACER_FLAGS="--enable=1" + If it is enabled, it will trace the output tensor values of + selected Ops in the graph. It has two outputs: (1) the traces and (2) + a report. The traces are dumped to a specified local file on the TPU + host. The report is printed to the log.info of the TPU job. + By passing options via the env variable, users can change: + (1) the trace mode (e.g., detecting NaN/Inf, printing partial or + full tensor values) + (2) which Ops to be traced (via op.name or op.type) + (3) output trace file path. 
+ """ + + @staticmethod + def _match_next_flag(flags, pos): + """Returns the match for the next TensorTracer flag.""" + + match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) + if match: + return match + match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) + if match: + return match + match = _FLAG_NO_QUOTE_PAT.match(flags, pos) + return match + + @staticmethod + def print_flag_values(): + """Prints all TensorTracer flags passed via environment variables.""" + + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR + result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, + tensor_tracer_flags) + result += 'Individual flag value:\n' + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + flag_value = match.group(2) + result += ' %s: %s\n'%(flag_name, flag_value) + pos = match.end() + result += '\n' + return result + + @staticmethod + def get_flag_value(wanted_flag_name): + """Returns the value of a TensorTracer flags.""" + + tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return '' + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + return '' + flag_name = match.group(1) + flag_value = match.group(2) + if flag_name == wanted_flag_name: + return flag_value + pos = match.end() + return '' + + @staticmethod + def is_enabled(): + """Returns True if TensorTracer is enabled.""" + + flag_value = TensorTracer.get_flag_value(_FLAG_NAME_ENABLE) + flag_value = flag_value.lower() + enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] + return enabled + + @staticmethod + def use_test_undeclared_outputs_dir(): + """Decides the output directory of the trace file. + + Args: + None. + + Returns: + True if the output trace file should be written to the + test-undeclared-outputs-directory defined via an + env variable. 
+ """ + + flag_value = TensorTracer.get_flag_value( + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) + flag_value = flag_value.lower() + enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] + return enabled + + @staticmethod + def check_device_type(device_type): + """Checks if the given device type is valid.""" + + if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: + raise ValueError('Invalid device_type "%s"'%device_type) + + @staticmethod + def check_trace_mode(trace_mode): + """Checks if the given trace mode is valid.""" + + valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, + _TRACE_MODE_FULL_TENSOR] + if trace_mode not in valid_trace_modes: + raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' + 'Valid trace modes are: %s'%(trace_mode, + valid_trace_modes)) + + @staticmethod + def should_trace(device_type, op): + """Returns True if the given Op should be traced.""" + + if device_type != _DEVICE_TYPE_TPU: + raise ValueError('Non TPU device type is not supported') + if control_flow_util.IsInCond(op): + return False + if op.type in ['Reshape', 'ArgMin', 'ArgMax']: + return False + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR in op.node_def.attr + # pylint: enable=protected-access + + @staticmethod + def reason(op_idx, details): + """Returns why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) + + @staticmethod + def topological_sort(g): + """Performs topological sort on the given graph. + + Args: + g: the graph. + + Returns: + A pair where the first element indicates if the topological + sort succeeded (True if there is no cycle found; False if a + cycle is found) and the second element is either the sorted + list of nodes or the cycle of nodes found. + """ + + def visit(op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops): + """Recursively visits all Ops in a graph. + + Args: + op: the current Op being visited. + cycle: a cycle of Ops found. 
+ permanently_marked_ops: the set of Ops that were already visited. + temporarily_marked_ops: the set of Ops that we have visited during + the current descent. + sorted_ops: the list of Ops sorted in topological order. + """ + + if cycle: + return + if op in permanently_marked_ops: + return + if op in temporarily_marked_ops: + cycle = temporarily_marked_ops + return + temporarily_marked_ops.add(op) + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + for consumer_op in out_tensor.consumers(): + visit(consumer_op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + # pylint: disable=protected-access + for ctrl_output_op in op._control_outputs: + # pylint: enable=protected-access + visit(ctrl_output_op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + temporarily_marked_ops.remove(op) + permanently_marked_ops.add(op) + sorted_ops.insert(0, op) + + graph_cycle = set([]) + sorted_ops = [] + permanently_marked_ops = set([]) + temporarily_marked_ops = set([]) + unsorted_ops = g.get_operations() + for op in unsorted_ops: + visit(op, graph_cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + if graph_cycle: + return (False, graph_cycle) + else: + assert len(unsorted_ops) == len(sorted_ops) + return (True, sorted_ops) + + def __init__(self): + """Initializes a TensorTracer. + + Sets the various member fields from the flags (if given) or the defaults. 
+ """ + self._version = 'use-outside-compilation' + self._device_type = None + self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) + if not self._trace_mode: + self._trace_mode = _TRACE_MODE_NAN_INF + TensorTracer.check_trace_mode(self._trace_mode) + self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE + self._instrument_records = {} + interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) + self._selected_ops = interesting_ops.split() + self._set_trace_file_path() + self._set_op_range() + self._num_replicas = None + self._replica_id = None + + def _add_replica_id_to_graph(self, num_replicas, result_tensor): + """Adds nodes for computing the replica ID to the graph.""" + + if not num_replicas: + self._replica_id = 'unknown' + return result_tensor + + self._num_replicas = num_replicas + + with ops.control_dependencies(None): + # Uses None as dependency to run outside of TPU graph rewrites. + self._replica_id = tpu_ops.tpu_replicated_input( + list(range(self._num_replicas)), + name='tt_replica_id') + use_replica_id = array_ops.identity(self._replica_id).op + with ops.control_dependencies([use_replica_id]): + # Adds a control dependency from the result_tensor to + # the replica_id to ensure that replica_id will be added to the graph. 
+ return array_ops.identity(result_tensor) + + def _set_trace_file_path(self): + """Sets the path of the output trace file.""" + + self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) + if not self._trace_file_path: + raise ValueError('--%s is not set in the environment variable %s' + %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) + elif TensorTracer.use_test_undeclared_outputs_dir(): + if os.path.isabs(self._trace_file_path): + raise ValueError('If use_test_undeclared_outputs_dir is set,' + 'trace_file_path cannot be an absolute path (%s)' + %self._trace_file_path) + outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) + self._trace_file_path = os.path.join(outputs_dir, + self._trace_file_path) + + def _set_op_range(self): + """Sets the index range of the Ops that we will consider tracing.""" + + op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) + if not op_range: + self._op_range = (-1, -1) # this means including all ops. + return + match = _OP_RANGE_PAT.match(op_range) + if not match: + self._op_range = (-1, -1) # this means including all ops. 
+ return + self._op_range = (int(match.group(1)), int(match.group(2))) + + def _inside_op_range(self, idx): + """Return True if the given index is inside the selected range.""" + + if idx < self._op_range[0]: + return False + return self._op_range[1] < 0 or idx <= self._op_range[1] + + def _write_report(self, content): + """Writes the given content to the report.""" + + logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + + def _is_selected_op(self, op_name): + """Returns True if the Op with op_name is selected to be traced.""" + + if not self._selected_ops: + return True + if op_name in self._selected_ops: + return True + return False + + def _write_config_section(self): + """Writes the config section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) + self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) + self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) + self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) + self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) + + def _write_reason_section(self): + """Writes the reason section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) + for key in sorted(self._instrument_records): + self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) + + def _write_op_list_section(self, op_list): + """Writes the Op-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) + for i in range(0, len(op_list)): + self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + + def 
_write_graph_section(self, succeed, sorted_or_cycle): + """Writes the graph section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) + self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, + succeed)) + l = list(sorted_or_cycle) + for i in range(0, len(l)): + self._write_report('%d "%s"\n'%(i, l[i].name)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) + + def _make_tensor_trace_fun(self, op_name, output_idx): + """Makes the tensor tracing function called by outside compilation. + + Args: + op_name: the name of the Op that outputs the tensor to be traced. + output_idx: which output of the Op it is (0 means the first output). + + Returns: + A function to be passed as the first argument to outside compilation. + + Raises: + RuntimeError: If the trace mode is invalid. + """ + + def _print_tensor(op_name, output_idx, num_elements, tensor, output_tensor): + """Prints a tensor value to a file. + + Args: + op_name: the name of the Op that outputs the tensor to be printed. + output_idx: which output of the Op it is (0 means the first output). + num_elements: number of elements to print. + tensor: the tensor needs to be returned. + output_tensor: the tensor needs to be printed. + + Returns: + The same tensor passed via the "tensor" argument. + """ + msg = '"%s:%d" '%(op_name, output_idx) + output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), + ' @', self._replica_id, + '\n', output_tensor, + summarize=num_elements, + output_stream=output_stream) + with ops.control_dependencies([print_op]): + return array_ops.identity(tensor).op + + def _detect_nan_inf(tensor): + """Trace function for detecting any NaN/Inf in the tensor.""" + + if tensor.dtype.is_floating: + # Since host can't handle bf16, always convert tensor to f32. 
+ tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = math_ops.reduce_any( + gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), + gen_math_ops.is_inf(tensor))) + else: + output_tensor = constant_op.constant(0) + return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) + + def _show_global_step(tensor): + """Trace function for printing the global step count.""" + + return _print_tensor(op_name, output_idx, 1, tensor, tensor) + + def _show_part_tensor(tensor): + """Trace function for printing part of the tensor.""" + + return _print_tensor(op_name, output_idx, self._part_tensor_size, + tensor, tensor) + + def _show_full_tensor(tensor): + """Trace function for printing the entire tensor.""" + + return _print_tensor(op_name, output_idx, -1, tensor, tensor) + + if op_name == _GLOBAL_STEP_OP_NAME: + return _show_global_step + if self._trace_mode == _TRACE_MODE_NAN_INF: + return _detect_nan_inf + if self._trace_mode == _TRACE_MODE_PART_TENSOR: + return _show_part_tensor + if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + return _show_full_tensor + + raise RuntimeError('Tensor trace fun for %s is not yet implemented' + %self._trace_mode) + + def trace_tpu(self, graph, result_tensor, num_replicas=None): + """Traces the tensors generated by TPU Ops in a TF graph. + + Args: + graph: the graph of Ops. + result_tensor: a result tensor of evaluating the graph. + num_replicas: number of replicas used on the TPU. + + Returns: + A tuple (result_tensor_copy, tracing_ops), where: + result_tensor_copy: an exact copy of result_tensor + tracing_ops: a list of tracing ops. If this list + is non empty, the caller of this function + should pose control dependencies upon these + Ops so that they will be executed when the + graph is evaluated. 
+ """ + + self._device_type = _DEVICE_TYPE_TPU + TensorTracer.check_device_type(self._device_type) + result_tensor_copy = self._add_replica_id_to_graph(num_replicas, + result_tensor) + self._write_config_section() + tracing_ops = [] + operations = graph.get_operations() + self._write_op_list_section(operations) + # Does the topological sort before adding any nodes to the graph. + (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + for op_id, op in enumerate(operations): + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_OUTSIDE_OP_RANGE) + continue + if not TensorTracer.should_trace(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_SHOULD_NOT_TRACE) + continue + if not self._is_selected_op(op.name): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_FILTERED_OUT) + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if not out_tensor.get_shape().is_fully_defined(): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_DYNAMIC_SHAPE) + continue # cannot trace tensors with dynamic shape. + rank = len(out_tensor.shape) + if rank < 1: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_SCALAR) + continue # cannot trace scalar. + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_GET_TRACED) + consumers = out_tensor.consumers() + trace_op = tpu.outside_compilation( + self._make_tensor_trace_fun(op.name, i), out_tensor) + if consumers: + for consumer_op in consumers: + # pylint: disable=protected-access + consumer_op._add_control_input(trace_op) + # pylint: enable=protected-access + else: + # if there is no consumer, we will add the control dependence later + # when we add the control dependency to the output operations. 
+ tracing_ops.append(trace_op) + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + + return (result_tensor_copy, tracing_ops) diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index b6bb5c6e56c74003ed8ceafe9246fb6a05d928dd..6ae718cc2c9716587849aeee8abcd0a1de82a9ae 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -189,12 +189,13 @@ class Topology(object): def cpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the CPU device attached to a logical core.""" return _tpu_host_device_name( - job, self._topology_tasks[device_coordinates]) + job, self._topology_tasks[tuple(device_coordinates)]) def tpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, self._topology_tasks[device_coordinates], - self._topology_devices[device_coordinates]) + return _tpu_device_name(job, + self._topology_tasks[tuple(device_coordinates)], + self._topology_devices[tuple(device_coordinates)]) @property def num_tasks(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index e3e791faacb9b3c1fedbd83d3740e35351e38abb..a02361241cec5d16c4b05406c8b53bfd58156f56 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -1001,8 +1001,8 @@ def rewrite(computation, `rewrite` is a list of tensors corresponding to the tensors from the output of `computation`. - All `Operation`s returned from `computation` will be executed when - evaluating any of the returned output tensors. + All `Operation`s constructed during `computation` will be executed when + evaluating any of the returned output tensors, not just the ones returned. inputs: A list of input tensors or `None` (equivalent to an empty list). 
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index da6bdf67d686fba09d66386de982b57aa28d4dd4..672462447944b777375331d49727c4d5366cf295 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -41,7 +41,7 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = { class TPUContext(object): - """The context of current input_fn invocation.""" + """A context that holds the current configuration of the TPU computation.""" def __init__(self, internal_ctx, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 555ad0f1fdbe36f078c7d2fdcc67571f28c8b723..932367f4dd546c7867ea75eba1ae36813c9080da 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support @@ -108,6 +109,15 @@ ops.register_proto_function( from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access +def _is_iterable(obj): + """A Python 2 and 3 compatible util to check whether `obj` is iterable.""" + try: + iter(obj) + return True + except TypeError: + return False + + def _create_global_step(graph): graph = graph or ops.get_default_graph() if training.get_global_step(graph) is not None: @@ -288,9 +298,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote host_calls['host_call'] = host_call 
_OutfeedHostCall.validate(host_calls) - training_hooks = list(training_hooks or []) - evaluation_hooks = list(evaluation_hooks or []) - prediction_hooks = list(prediction_hooks or []) + training_hooks = tuple(training_hooks or []) + evaluation_hooks = tuple(evaluation_hooks or []) + prediction_hooks = tuple(prediction_hooks or []) for hook in training_hooks + evaluation_hooks + prediction_hooks: if not isinstance(hook, session_run_hook.SessionRunHook): @@ -325,7 +335,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] - hooks = list(hooks or []) + hooks = tuple(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -1317,9 +1327,15 @@ class _ModelFnWrapper(object): captured_training_hooks.capture(estimator_spec.training_hooks) + tracing_ops = [] + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, + self._ctx.num_replicas) + # We must run train_op to update the variables prior to running the # outfeed. - with ops.control_dependencies([train_op]): + with ops.control_dependencies([train_op]+tracing_ops): host_call_outfeed_ops = [] if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access and estimator_spec.host_call is not None): @@ -2250,8 +2266,7 @@ class TPUEstimator(estimator_lib.Estimator): # Only fetching `tpu_tensors_on_cpu` does not trigger # TPU computation and blocks, so we add the control dependency here. 
control_inputs = ( - tpu_tensors_on_cpu if isinstance(tpu_tensors_on_cpu, - (list, tuple)) else + tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else (tpu_tensors_on_cpu,)) with ops.control_dependencies(control_inputs): new_tensors.append(array_ops.identity(t)) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index e75a09492ec12b95bad32b221a8e78a1b79f3a6b..cf36103277de2e3b055ae89c66b198fb55bb4522 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -26,7 +26,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding @@ -92,8 +91,7 @@ class InfeedQueue(object): else: raise ValueError( "number of tuple elements cannot be inferred from InfeedQueue " - "constructor" - ) + "constructor") if number_of_tuple_elements <= 0: raise ValueError("number_of_tuple_elements %d must be > 0" % number_of_tuple_elements) @@ -293,9 +291,8 @@ class InfeedQueue(object): self.number_of_tuple_elements """ if len(input_tensors) != self.number_of_tuple_elements: - raise ValueError( - "input_tensors is %s, but should be a list of %d Tensors", ( - str(input_tensors), self.number_of_tuple_elements)) + raise ValueError("input_tensors is %s, but should be a list of %d Tensors" + % (str(input_tensors), self.number_of_tuple_elements)) self.set_tuple_shapes([t.shape for t in input_tensors]) self.set_tuple_types([t.dtype for t in input_tensors]) @@ -451,8 +448,8 @@ class InfeedQueue(object): for i in xrange(1, self.number_of_tuple_elements): if devices[0] != devices[i]: raise ValueError( - "input devices for shard %d are %s, but should all be the same", - 
index, str(devices)) + "input devices for shard %d are %s, but should all be the same" % + (index, str(devices))) with ops.colocate_with(inputs[0]): return tpu_ops.infeed_enqueue_tuple( inputs=inputs, @@ -792,18 +789,14 @@ class _PartitionedInfeedQueue(InfeedQueue): Args: tensor: Input tensor for partitioning. - dims: A list of integer describes how to partition the input tensor. + dims: 1-D np.array of the list of integer describes how to partition the + input tensor. Raises: ValueError: If the tensor can't be partitioned by dims or the num_cores_per_replica doesn't match the number of partitions(dims.prod()). """ - if dims is None: - return - - dims = np.array(dims) - if (dims < 1).any(): raise ValueError("All input partition dims must be >= 1.") @@ -823,11 +816,6 @@ class _PartitionedInfeedQueue(InfeedQueue): "partition dims = {}).".format(tensor.shape.as_list(), dims)) tensor.shape.assert_is_fully_defined() - if (np.array(tensor.shape.as_list()) % dims != 0).any(): - raise ValueError( - "All input partition dims must divide exactly into the `Tensor` " - "shape (tensor shape = {}, input partition dims = {}).".format( - tensor.shape.as_list(), dims)) def _partition_or_replicate_on_host(self, tensor, dims): """Partitions or replicates the input tensor. @@ -840,16 +828,33 @@ class _PartitionedInfeedQueue(InfeedQueue): Returns: An iterator of `Tensor`s or a list of partioned tensors. 
""" - self._check_input_partition_dims(tensor, dims) if dims is None: return itertools.repeat(tensor) - else: - output = [tensor] - for axis, dim in enumerate(dims): - if dim > 1: - output = [array_ops.split(x, dim, axis=axis) for x in output] - output = nest.flatten(output) - return output + dims = np.array(dims) + self._check_input_partition_dims(tensor, dims) + output = [tensor] + divds, remainders = np.divmod(np.array(tensor.shape.as_list()), dims) + for axis, (divd, remainder, dim) in enumerate( + np.dstack((divds, remainders, dims))[0]): + if dim <= 1: + continue + if remainder > 0: + # For each dimension, when it cannot be evenly partitioned, XLA assumes + # the size of last parts are smaller by 1. E.g. 2D tensor with shape + # (5, 14) and dims are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] + # => [[(3, 3), (3, 3), (2, 3), (2, 3)], + # [(2, 3), (2, 3), (2, 2), (2, 2)]] + output = [ + array_ops.split( + x, + num_or_size_splits=[divd + 1] * remainder + + [divd] * (dim - remainder), + axis=axis) for x in output + ] + else: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): """Tags appropriate XLA sharding attribute to the dequeued tensor. 
@@ -866,13 +871,9 @@ class _PartitionedInfeedQueue(InfeedQueue): elif np.prod(dims) == 1: return xla_sharding.assign_device(tensor, 0) else: - tile_shape = np.array(tensor.shape.as_list()) // dims tile_assignment = np.arange(np.prod(dims)).reshape(dims) return xla_sharding.tile( tensor=tensor, - tile_shape=xla_shape.CreateShapeFromDtypeAndTuple( - dtype=np.dtype(tensor.dtype.as_numpy_dtype), - shape_tuple=tile_shape), tile_assignment=tile_assignment) def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py index 338acef63f244613cbd14a2da04c7ec4d811a0af..acc5a049aa87649e4f8bf3a00be605616ea7b630 100644 --- a/tensorflow/contrib/util/__init__.py +++ b/tensorflow/contrib/util/__init__.py @@ -15,8 +15,6 @@ """Utilities for dealing with Tensors. -See [Contrib Util](https://tensorflow.org/api_guides/python/contrib.util) guide. - @@constant_value @@make_tensor_proto @@make_ndarray diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index f7c979e86320d59ad033e2b8d7fcdff89ce0d133..9db80f6b5736d849d88e1e41ea467a5ff11844f5 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -1028,7 +1027,10 @@ Status RdmaTensorResponse::PrepareRecvTensor( return errors::Aborted( "RecvTensor expects a different device incarnation: ", parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(), - ". Your worker job was probably restarted. Check your " + ". 
Your worker job (\"", + channel_->adapter_->worker_env_->session_mgr->LegacySession() + ->worker_name, + "\") was probably restarted. Check your " "worker job for the reason why it was restarted."); } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index afe4c46c8efc59da3da07777ee1fd38be015753d..2a8c2718edd7faa844d2efb7e7ea007db48d846b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -300,6 +300,7 @@ filegroup( "platform/env_time.h", "platform/logging.h", "platform/macros.h", + "platform/platform_strings.h", "platform/types.h", ], visibility = ["//visibility:private"], @@ -383,6 +384,7 @@ cc_library( ":lib_platform", ":platform_base", "//tensorflow/core/platform/default/build_config:port", + "@com_google_absl//absl/base", "@snappy", ], ) @@ -518,6 +520,19 @@ cc_library( ], ) +cc_library( + name = "platform_strings", + srcs = tf_platform_srcs([ + "platform/platform_strings.cc", + "platform/platform_strings_computed.h", + ]), + hdrs = [ + "platform/platform_strings.h", + ], + visibility = ["//tensorflow/core:__subpackages__"], + deps = [":lib"], +) + filegroup( name = "platform_other_hdrs", srcs = [ @@ -1037,6 +1052,7 @@ tf_gen_op_libs( "batch_ops", "bitwise_ops", "boosted_trees_ops", + "tensor_forest_ops", "candidate_sampling_ops", "checkpoint_ops", "collective_ops", @@ -1057,6 +1073,7 @@ tf_gen_op_libs( "logging_ops", "manip_ops", "math_ops", + "mkl_nn_ops", "nccl_ops", "nn_ops", "no_op", @@ -1185,6 +1202,7 @@ cc_library( ":batch_ops_op_lib", ":bitwise_ops_op_lib", ":boosted_trees_ops_op_lib", + ":tensor_forest_ops_op_lib", ":candidate_sampling_ops_op_lib", ":checkpoint_ops_op_lib", ":collective_ops_op_lib", @@ -1229,7 +1247,7 @@ cc_library( ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ] + tf_additional_cloud_op_deps(), + ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -1285,7 +1303,9 @@ cc_library( ":framework", ":lib", ":nn_ops_op_lib", - ], + ] + if_mkl([ + 
":mkl_nn_ops_op_lib", + ]), alwayslink = 1, ) @@ -1336,6 +1356,7 @@ cc_library( "//tensorflow/core/kernels:batch_kernels", "//tensorflow/core/kernels:bincount_op", "//tensorflow/core/kernels:boosted_trees_ops", + "//tensorflow/core/kernels:tensor_forest_ops", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:checkpoint_ops", "//tensorflow/core/kernels:collective_ops", @@ -1667,6 +1688,7 @@ cc_library( cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1771,6 +1793,7 @@ cc_library( deps = [ ":protos_all_cc_impl", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_set", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", "@protobuf_archive//:protobuf", @@ -1795,6 +1818,7 @@ cc_library( deps = [ ":protos_all_cc_impl", "//third_party/eigen3", + "@com_google_absl//absl/container:flat_hash_set", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", "@protobuf_archive//:protobuf", @@ -2168,6 +2192,7 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", + ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2635,6 +2660,7 @@ tf_cuda_library( ":stats_calculator_portable", ":version_lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", "//third_party/eigen3", @@ -2811,7 +2837,6 @@ tf_cuda_library( ":functional_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:required", - ":core_cpu_impl", ]), alwayslink = 1, ) @@ -3048,7 +3073,9 @@ tf_cuda_library( ], copts = tf_copts(), cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()), - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow:internal", + ], 
deps = [ ":core_cpu_internal", ":lib", @@ -3402,6 +3429,16 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_strings_test", + size = "small", + srcs = ["platform/platform_strings_test.cc"], + deps = [ + ":lib", + ":platform_strings", + ], +) + tf_cc_test( name = "platform_env_test", size = "small", @@ -4080,6 +4117,7 @@ tf_cc_test( "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:immutable_constant_op", "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:topk_op", "//third_party/eigen3", ], ) @@ -4852,6 +4890,7 @@ transitive_hdrs( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:platform_strings", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor", ], diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 6f9885691595368ab50cfe660b1b5c75673063cf..d38a8424eb13009fbf84d7511fb1325085d8b809 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -182,11 +182,14 @@ void TestDeprecationVersionSetCorrectly( for (const auto& name_and_api_def : api_defs_map) { const auto& name = name_and_api_def.first; const auto& api_def = name_and_api_def.second; - ASSERT_TRUE(api_def.deprecation_version() == 0 || - api_def.deprecation_message().empty()) - << "ApiDef that includes deprecation_version > 0 must also specify " - << "a deprecation_message. Op " << name - << " has deprecation_version > 0 but deprecation_message is not set."; + if (api_def.deprecation_version() != 0) { + ASSERT_TRUE(api_def.deprecation_version() > 0) + << "Found ApiDef with negative deprecation_version"; + ASSERT_FALSE(api_def.deprecation_message().empty()) + << "ApiDef that includes deprecation_version > 0 must also specify " + << "a deprecation_message. 
Op " << name + << " has deprecation_version > 0 but deprecation_message is not set."; + } } } } // namespace diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt index cdaeb5091c7b407addec2811bbf0cb79e61db2d2..bfaf3d2ea5912bf5fde34a91ec51ad42f66b6adb 100644 --- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt @@ -4,7 +4,7 @@ op { in_arg { name: "float_values" description: <