diff --git a/WORKSPACE b/WORKSPACE index 17961829a605c2d1f2d2ba86a7c30c47618c139b..0c7bc085b512b084b9470abe17326d7c119aa327 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -14,6 +14,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +http_archive( + name = "base_images_docker", + sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", + strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", + urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], +) + +http_archive( + name = "bazel_toolchains", + sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", + strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", + urls = [ + "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", + ], +) + +http_archive( + name = "io_bazel_rules_docker", + sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", + strip_prefix = "rules_docker-0.5.1", + urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +) + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") + +remote_config_workspace() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. 
@@ -79,3 +106,4 @@ new_http_archive( "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) + diff --git a/configure.py b/configure.py index 2eeeceb3399c79775ce62b9569e940e469141a17..234561d94a46f57c4de5ca487360e2d5a3dfdb2f 100644 --- a/configure.py +++ b/configure.py @@ -43,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -1555,6 +1555,9 @@ def main(): check_bazel_version('0.15.0') reset_tf_configure_bazelrc() + # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later + write_to_bazelrc('import %workspace%/tools/bazel.rc') + cleanup_makefile() setup_python(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 11b42f349df89c605f4ce1130033a85c920258c9..859dc3b8d77be66e0f51e15d86188399273af23f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -352,6 +352,7 @@ package_group( "//tensorflow/...", "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", + "//tensorflow_text/...", "//third_party/py/tensor2tensor/...", ], ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 16f633643d4726f6e2d1a23c3b192d48dbbc8f14..b8db1b2144978e97bd32f62e643c2c4a7fcf1654 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -95,6 +95,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:server_lib", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -199,7 +200,7 @@ tf_cuda_cc_test( size = "small", srcs = ["c_api_test.cc"], data = [ - ":test_op.so", + ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], kernels = [":test_op_kernel"], @@ -218,6 
+219,7 @@ tf_cuda_cc_test( "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/compiler/jit", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", @@ -284,8 +286,8 @@ tf_cc_test( ) tf_custom_op_library( - name = "test_op.so", - srcs = ["test_op.cc"], + name = "test_op1.so", + srcs = ["test_op1.cc"], ) tf_kernel_library( diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 4540dcd6638a58c25628dccd2fa78f1fe06bef1d..f13e8777dff164bcd8eedf46310ae846abd0c804 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2810,4 +2810,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { } return ret; } + +// TF_Server functions ---------------------------------------------- + +#ifndef __ANDROID__ +TF_Server::TF_Server(std::unique_ptr server) + : target(server->target()), server(std::move(server)) {} +#endif // __ANDROID__ + +TF_Server* TF_NewServer(const void* proto, size_t proto_len, + TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); + return nullptr; +#else + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { + status->status = InvalidArgument( + "Could not parse provided bytes into a ServerDef protocol buffer"); + return nullptr; + } + + std::unique_ptr out_server; + status->status = tensorflow::NewServer(server_def, &out_server); + if (!status->status.ok()) return nullptr; + + return new TF_Server(std::move(out_server)); +#endif +} + +void TF_ServerStart(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Start(); +#endif +} + +void TF_ServerStop(TF_Server* server, TF_Status* 
status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Stop(); +#endif +} + +void TF_ServerJoin(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Join(); +#endif +} + +const char* TF_ServerTarget(TF_Server* server) { +#ifdef __ANDROID__ + return nullptr; +#else + return server->target.c_str(); +#endif +} + +void TF_DeleteServer(TF_Server* server) { delete server; } + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index da8ad1cec59e328b9f1a77f81416651a618e97d3..3d56268110edbe96616201d15a69cc8c84d3115a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1668,6 +1668,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. 
+TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_DeleteServer). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroy an in-process TensorFlow server, frees memory. If server is running +// it will be stopped and joined. +TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 95652a11378d6276b5ba6540a07baa15aa77cc1c..5ba26d3c585350aa510f9970cbfc246a9a108543 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -25,6 +25,7 @@ limitations under the License. 
#include #ifndef __ANDROID__ +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -179,6 +180,15 @@ struct TF_ApiDefMap { tensorflow::mutex lock; }; +#ifndef __ANDROID__ +struct TF_Server { + TF_Server(std::unique_ptr server); + + const tensorflow::string target; + std::unique_ptr server; +}; +#endif + namespace tensorflow { class TensorCApi { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index b0dc0363fdb266a7bb8babcd41ac469b5e763551..d5934a10395ae094f65d3bc8b6cd7b94dbd32410 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -187,15 +187,26 @@ TEST(CAPI, LibraryLoadFunctions) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - // Load the library. - TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - TF_Code code = TF_GetCode(status); - string status_msg(TF_Message(status)); - TF_DeleteStatus(status); - ASSERT_EQ(TF_OK, code) << status_msg; +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + // Load the library. + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op1.so", status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + // Test op list. 
+ TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi1", op_list.op(0).name()); + TF_DeleteLibraryHandle(lib); + } +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) { TF_Buffer* op_list_buffer = TF_GetAllOpList(); tensorflow::OpList op_list; @@ -210,19 +221,6 @@ TEST(CAPI, LibraryLoadFunctions) { EXPECT_TRUE(found); TF_DeleteBuffer(op_list_buffer); } - -#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) - { - // Test op list. - TF_Buffer op_list_buf = TF_GetOpList(lib); - tensorflow::OpList op_list; - EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); - ASSERT_EQ(op_list.op_size(), 1); - EXPECT_EQ("TestCApi", op_list.op(0).name()); - } -#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) - - TF_DeleteLibraryHandle(lib); } void TestEncodeDecode(int line, const std::vector& data) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3ee31a6a7ac641bbd3fc4c05568b61e433a1d523..ba3d8533db7623b8fa7fdf35093abcd1450776b1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -69,7 +69,7 @@ tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], visibility = [ - "//learning/deepmind/courier:__pkg__", + "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3554ec0bf3202b54bfc38d67e51b89df19832302..408277468d7beb23d1b2ab7f9bbccac16332e55a 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -404,8 +404,7 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { "The passed in handle is a nullptr"); return nullptr; } - tensorflow::Device* d = nullptr; - status->status = h->handle->OpDevice(&d); + tensorflow::Device* d = h->handle->op_device(); return (d == nullptr) ? 
"/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 5006b76f1981d068e99a2c081115ebb3a66d8c7f..52b0824552855860dfb138f3ac9a5d3afa7dc965 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } - tensorflow::Device* device; - status->status = handle->handle->Device(&device); - if (!status->status.ok()) { - return nullptr; - } - #ifdef TENSORFLOW_EAGER_USE_XLA + tensorflow::Device* device = handle->handle->device(); + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. tensorflow::XlaDevice* xla_device = dynamic_cast(device); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 104d52430cf7aa14d4d2a335a1b96e667f21ce87..fa1b22e3af487b19b8b7885b7c3740b6249c73eb 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -79,10 +79,6 @@ struct TFE_TensorHandle { tensorflow::Device* op_device) : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} - TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - tensorflow::EagerContext* ctx) - : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} - TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; diff --git a/tensorflow/c/test_op1.cc b/tensorflow/c/test_op1.cc new file mode 100644 index 0000000000000000000000000000000000000000..b22cc9aef2b344282f45340ff12ee849935a26f9 --- /dev/null +++ b/tensorflow/c/test_op1.cc @@ -0,0 +1,23 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc"); + +} // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index c18b07603ae3841d3581741ab5a43f2e8b628356..83353b79f722f0a95f508b32d4a49b14b35624fb 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -170,6 +170,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -516,6 +517,8 @@ tf_gen_op_wrappers_cc( ":array_ops", ":const_op", ":math_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", ], ) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 6c29f09cde7ee17c11cb44ce48d8e9128daae4d0..16151e77737429f4fbf690fc34b12a70bacebdc4 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -93,7 +93,7 @@ cc_library( ":tfcompile_lib", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 
b95b063348c5cdfdcaed635ba527e9f0bfd6092d..d548de8c44285f6d21dd778db464a31e1b19645b 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) { return errors::InvalidArgument("Must specify --cpp_class"); } codegen_opts.gen_hlo_profile_printer_data = - xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + xla::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); @@ -132,7 +132,7 @@ int main(int argc, char** argv) { std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0c41e095c7bc73e517d4c11c590a21439db1e3da..5f25e4626ad1cc3510b2508574ca34c29bdf20ce 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -21,7 +21,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") @@ -52,6 +51,7 @@ cc_library( deps = [ 
":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -65,6 +65,7 @@ cc_library( ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), alwayslink = 1, @@ -190,6 +191,7 @@ cc_library( "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:stack", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", @@ -241,6 +243,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) @@ -253,6 +256,7 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:core_cpu", @@ -263,6 +267,21 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "xla_compilation_cache_test", + srcs = [ + "xla_compilation_cache_test.cc", + ], + deps = [ + ":xla_compilation_cache", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -500,6 +519,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -524,25 +544,6 @@ cc_library( hdrs = ["union_find.h"], ) -cc_library( - name = "producer_consumer_queue", - hdrs = ["producer_consumer_queue.h"], - deps = ["//tensorflow/core:lib"], -) - -tf_cc_test( - name = "producer_consumer_queue_test", - size = "small", - srcs = ["producer_consumer_queue_test.cc"], - deps = [ - ":producer_consumer_queue", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -606,6 +607,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -648,31 +650,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "xla_launch_util_test", - size = "small", - srcs = ["xla_launch_util_test.cc"], - deps = [ - ":common", - ":xla_compilation_cache", - ":xla_launch_util", - ":xla_tensor", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core/kernels:variable_ops", - ], -) - cc_library( name = "xla_fusion_optimizer", srcs = ["xla_fusion_optimizer.cc"], diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 
054f31ba3352b2215e6b0448c8ec8a70cb98b8e5..93637a69d5d7b6bf9e9ce784ae521ef0e9b121b9 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -214,7 +214,8 @@ Status NodeRequiresCompilation(Node* n, bool* result) { return errors::Internal("Could not find compilation device ", device_type.type()); } - *result = registration->requires_compilation; + *result = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; return Status::OK(); } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 617e31488c7daeb714c0ff7056b786e4eaf7873f..8a73101c184e6190921fd7729742922bd96f4bcf 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output loop_cond = ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), latch.output_true, increment_by); Output next_iteration = @@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue( value, frame_name); ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output next_iteration = ops::NextIteration( root.WithOpName(prefix + "/next_iteration"), latch.output_true); CHECK(root.graph() diff --git a/tensorflow/compiler/jit/encapsulate_util.h 
b/tensorflow/compiler/jit/encapsulate_util.h index a3b193eea745d4e44781225130216253c19371da..5e0c4bf6a0cc92d69209595e257989665404db6b 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -117,6 +117,25 @@ Status PreprocessForEncapsulation(Graph* g, // Information for XLA computation. struct XlaClusterInfo { + // Add an explicitly-defined default constructor for this class. + // + // The compiler may delete the default constructor here because + // host_compute_core is a const member whose type (std::map) doesn't + // necessarily have a user provided constructor -- while libc++ and + // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at + // least >= 7.3 does not. See also c++11 [class.ctor] p5. + // + // TODO(klimek): In c++17 we'll be able to initialize host_compute_core + // without losing aggregate initialization, which allows us to get rid of + // the constructor definitions again. + XlaClusterInfo() {} + XlaClusterInfo(const string& cluster_name, + const NameAttrList& func_name_attrs, Node* node, + const std::map& host_compute_core) + : cluster_name(cluster_name), + func_name_attrs(func_name_attrs), + node(node), + host_compute_core(host_compute_core) {} // XLA cluster name. It might be different from `func_name`. const string cluster_name; // Name and attributes of XLA computation function. 
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 70b019d35fc80c975bc23ef42d61e3e36e4d0924..8b3587c5087a0651c466f53f3709ba21e75dd273 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -394,12 +394,12 @@ Status ConstructHostGraph( for (const string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; FunctionBody* host_fbody = nullptr; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fld->Find(host_func), AttrSlice(), fld, - [&](const string& op, const OpDef** sig) { - return fld->LookUpOpDef(op, sig); - }, - &host_fbody)); + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_func), AttrSlice(), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &host_fbody)); std::unique_ptr host_fbody_deleter(host_fbody); // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse @@ -411,52 +411,53 @@ Status ConstructHostGraph( node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); Status s; - ReverseDFS(*host_fbody->graph, /*enter=*/nullptr, - [&](const Node* n) { - if (!s.ok()) { - return; - } - - Node* copy; - if (node_map.find(n) != node_map.end()) { - // Already copied this node. - copy = node_map.at(n); - } else if (IsKeyPlaceholderNode(*n)) { - // Change a). - copy = key_placeholder; - node_map[n] = copy; - } else { - // Copy the node. - NodeDef copy_def = n->def(); - // Change c). - copy_def.clear_device(); - copy = (*host_graph)->AddNode(copy_def, &s); - if (!s.ok()) { - return; - } - node_map[n] = copy; - } - - // Only handle input edges. Output edges will be added later as - // its output nodes' input edges. 
- for (auto e : n->in_edges()) { - if (node_map.find(e->src()) == node_map.end()) { - s = errors::Internal("Cannot find node image for ", - e->src()->DebugString()); - return; - } - (*host_graph) - ->AddEdge(node_map[e->src()], e->src_output(), copy, - e->dst_input()); - } - - // Change b). - if (copy->type_string() == "_XlaRecvAtHost" || - copy->type_string() == "_XlaSendFromHost") { - (*host_graph)->AddControlEdge(copy, sequencer); - } - }, - NodeComparatorID()); + ReverseDFS( + *host_fbody->graph, /*enter=*/nullptr, + [&](const Node* n) { + if (!s.ok()) { + return; + } + + Node* copy; + if (node_map.find(n) != node_map.end()) { + // Already copied this node. + copy = node_map.at(n); + } else if (IsKeyPlaceholderNode(*n)) { + // Change a). + copy = key_placeholder; + node_map[n] = copy; + } else { + // Copy the node. + NodeDef copy_def = n->def(); + // Change c). + copy_def.clear_device(); + copy = (*host_graph)->AddNode(copy_def, &s); + if (!s.ok()) { + return; + } + node_map[n] = copy; + } + + // Only handle input edges. Output edges will be added later as + // its output nodes' input edges. + for (auto e : n->in_edges()) { + if (node_map.find(e->src()) == node_map.end()) { + s = errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + return; + } + (*host_graph) + ->AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); + } + + // Change b). 
+ if (copy->type_string() == "_XlaRecvAtHost" || + copy->type_string() == "_XlaSendFromHost") { + (*host_graph)->AddControlEdge(copy, sequencer); + } + }, + NodeComparatorID()); if (!s.ok()) { return s; } @@ -838,7 +839,12 @@ Status ExtractOutsideCompilationForFunction( FunctionDef shape_inference_fdef = *xla_fdef; shape_inference_fdef.mutable_signature()->set_name( shape_inference_graph); - TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + if (fld->Find(shape_inference_graph)) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + } } } } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index bd8719b7f1acb79e0b0cd91f2f0de0d66d8dab46..d984ca15cb722821b2a466a90387a29cbc1d1097 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/types/optional.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" @@ -34,14 +35,30 @@ limitations under the License. namespace tensorflow { namespace { -Status GetTensorFromConstOp(Node* n, Tensor* out_tensor) { - TF_RET_CHECK(n->type_string() == "Const"); + +// StatusOrOptional instances hold +// +// - A non-OK Status to indicate an error that needs to be propagated out of +// this pass (e.g. the Graph is malformed). +// +// - A nullopt to indicate the function that created the instance failed to do +// what it set out to do but this is not actually an error +// (e.g. TryToGetTensorFromConstOp was passed a non-Const node). +// +// - A T to indicate a successful operation. 
+template +using StatusOrOptional = xla::StatusOr>; + +StatusOrOptional TryToGetTensorFromConstOp(Node* n) { + if (n->type_string() != "Const") { + return {absl::nullopt}; + } + const TensorProto* proto = nullptr; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); Tensor tensor(proto->dtype()); TF_RET_CHECK(tensor.FromProto(*proto)); - *out_tensor = std::move(tensor); - return Status::OK(); + return {tensor}; } struct SliceInputs { @@ -70,7 +87,7 @@ std::vector IntTensorAsVector(const Tensor& t) { // Packages up the inputs to a Slice operation into an instance of // `SliceInputs`. -Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { +StatusOrOptional GetSliceInputs(Node* slice) { const int kSliceInputIndex = 0; const int kSliceBeginIndex = 1; const int kSliceSizeIndex = 2; @@ -81,23 +98,27 @@ Status GetSliceInputs(Node* slice, SliceInputs* slice_inputs) { TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge)); const Edge* slice_begin_edge; TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge)); - slice_inputs->input = + + SliceInputs slice_inputs; + slice_inputs.input = Output(slice_input_edge->src(), slice_input_edge->src_output()); - slice_inputs->begin = + slice_inputs.begin = Output(slice_begin_edge->src(), slice_begin_edge->src_output()); - slice_inputs->size = + slice_inputs.size = Output(slice_size_edge->src(), slice_size_edge->src_output()); - Tensor tf_slice_size; - TF_RETURN_IF_ERROR( - GetTensorFromConstOp(slice_inputs->size.node(), &tf_slice_size)); + TF_ASSIGN_OR_RETURN(absl::optional tf_slice_size, + TryToGetTensorFromConstOp(slice_inputs.size.node())); + if (!tf_slice_size.has_value()) { + return {absl::nullopt}; + } - if (tf_slice_size.dims() != 1) { - return errors::Internal("Expected vector for the slice size input."); + if (tf_slice_size->dims() != 1) { + return {absl::nullopt}; } - slice_inputs->size_as_vector = IntTensorAsVector(tf_slice_size); - return Status::OK(); + 
slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size); + return {slice_inputs}; } // Casts `x` to a DT_INT64 if it isn't one already. @@ -263,36 +284,43 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, return Status::OK(); } -// Returns true if `n` is a slice we can rewrite to have a static shape -// (i.e. have the output shape only depend on the "size" input). Fills in -// `slice_inputs` in the process. -bool IsRewritableSlice(Node* n, SliceInputs* slice_inputs) { +// If `n` is a slice we can rewrite to have a static shape (i.e. have the output +// shape only depend on the "size" input) then returns a SliceInputs +// representing the inputs to `n`. Otherwise returns nullopt. +StatusOrOptional IsRewritableSlice(Node* n) { if (n->type_string() != "Slice") { - return false; + return {absl::nullopt}; } if (!GetXlaClusterForNode(*n).has_value()) { // There is no need to change slice ops outside XLA clusters. - return false; + return {absl::nullopt}; } - if (!GetSliceInputs(n, slice_inputs).ok()) { - // Could not parse slice inputs. E.g. the sizes input was not a constant. - return false; + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + if (!slice_inputs.has_value()) { + return {absl::nullopt}; } // If slice_size[i] < -1 for any i then executing the slice will throw an // error, and we don't do anything here. 
- return absl::c_all_of(slice_inputs->size_as_vector, - [](int64 size_i) { return size_i >= -1; }); + bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector, + [](int64 size_i) { return size_i >= -1; }); + if (!slice_is_ok) { + return {absl::nullopt}; + } + + return slice_inputs; } Status FindAndRewriteSlices(Graph* g, bool* changed) { std::vector> slices_to_rewrite; for (Node* n : g->nodes()) { - SliceInputs slice_inputs; - if (IsRewritableSlice(n, &slice_inputs)) { - slices_to_rewrite.push_back({n, std::move(slice_inputs)}); + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + IsRewritableSlice(n)); + if (slice_inputs.has_value()) { + slices_to_rewrite.push_back({n, std::move(*slice_inputs)}); } } diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 107d521077c3fe2ac72d113d46e2566c78c9fafb..f79bdc1e2e8d82c9144d1bb9923ad36d8541cbdb 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -44,11 +44,8 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); -// TODO(b/111210515): IncreaseDynamismForAutoJitPass creates slices with index -// type DT_INT64 which do not have a kernel on GPU. 
-// -// REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, -// IncreaseDynamismForAutoJitPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + IncreaseDynamismForAutoJitPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 6bcae1dcc3dcf87faa5317e0064c4c0cf80af465..055de7afcc538a1a1183f3687d998a5b2211c887 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -39,12 +39,22 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + namespace tensorflow { namespace { -Status PlatformInfoFromContext(OpKernelConstruction* ctx, - XlaPlatformInfo* result) { +XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; @@ -76,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx, } if (!device_allocator) { - TF_ASSIGN_OR_RETURN(se::Platform* const platform, - se::MultiPlatformManager::PlatformWithId(platform_id)); + xla::StatusOr maybe_platform = + se::MultiPlatformManager::PlatformWithId(platform_id); + OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); + xla_allocator = absl::make_unique( - platform, ctx->device()->GetAllocator({})); + maybe_platform.ValueOrDie(), 
ctx->device()->GetAllocator({})); } - *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - std::move(xla_allocator), device_allocator); - - return Status::OK(); + return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); } // A closure describing how to run a compiled version of a TensorFlow function. @@ -179,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, : OpKernel(ctx), constants_(constants), resources_(resources), - function_(function) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} + function_(function), + platform_info_(PlatformInfoFromContext(ctx)) {} static Status BuildCompilationCache(OpKernelContext* ctx, const XlaPlatformInfo& platform_info, @@ -277,8 +286,10 @@ static Status CompileToLocalExecutable( // rather than a one-element tuple. compile_options.always_return_tuple = false; - return cache->Compile(options, function, constant_args, *variables, ctx, - compile_options, + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_args, *variables, ctx, &args)); + return cache->Compile(options, function, args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, kernel, executable); @@ -333,18 +344,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - // Helper static functions to construct parameters for // XlaLocalLaunchBase constructor from OpKernelConstruction. 
std::vector ConstantsVector(OpKernelConstruction* ctx) { @@ -381,7 +380,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) { return *func; } -#undef OP_REQUIRES_OK_RETURN +bool MustCompileAttr(OpKernelConstruction* ctx) { + bool must_compile; + OP_REQUIRES_OK_RETURN(ctx, false, + ctx->GetAttr("must_compile", &must_compile)); + return must_compile; +} } // namespace XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) @@ -396,10 +400,9 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx), constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), - function_(FunctionAttr(ctx)) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("must_compile", &must_compile_)); -} + function_(FunctionAttr(ctx)), + platform_info_(PlatformInfoFromContext(ctx)), + must_compile_(MustCompileAttr(ctx)) {} void XlaCompileOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaCompileOp " << def().name() @@ -409,13 +412,30 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; std::map variables; - if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation) { + bool cannot_compile_cluster; + { + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster = cannot_compile_cluster_; + } + + if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + cannot_compile_cluster) { executable = nullptr; } else { - OP_REQUIRES_OK(ctx, CompileToLocalExecutable( - ctx, function_, platform_info_, resources_, - constants_, /*lazy=*/!must_compile_, &client, - &variables, &kernel, &executable)); + Status status = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, + /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); + if (must_compile_ || status.code() != error::UNIMPLEMENTED) { + OP_REQUIRES_OK(ctx, status); + } + + if (status.code() == error::UNIMPLEMENTED) { + LOG(WARNING) << 
"Compilation failed:" << status.ToString() + << ". Falling back to TF function call."; + executable = nullptr; + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster_ = true; + } } AllocatorAttributes host_alloc_attrs; @@ -452,9 +472,8 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(1, compilation_successful); } -XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index ac90837e0d90943b93e2cdb01a30fa0837ba94df..7b4d4b5b4737784d4fe277d5bbe9cab79cfaf4c9 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#include + #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" @@ -33,6 +35,7 @@ namespace tensorflow { class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; explicit XlaPlatformInfo(const DeviceType device_type, se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, @@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel { protected: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; - XlaPlatformInfo platform_info_; + const NameAttrList function_; + const XlaPlatformInfo platform_info_; }; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph @@ -144,15 +147,23 @@ class XlaCompileOp : public OpKernel { private: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; + const NameAttrList function_; XlaPlatformInfo platform_info_; - bool must_compile_; + const bool must_compile_; + + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented + // error when compiling the cluster this _XlaCompile is supposed to compile. + // If `cannot_compile_cluster_` is true then we avoid compiling this cluster + // on any future calls to _XlaCompile. 
+ bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false; + + mutex cannot_compile_cluster_mu_; }; class XlaRunOp : public OpKernel { @@ -162,7 +173,7 @@ class XlaRunOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - XlaPlatformInfo platform_info_; + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 49ff9a3ddd1fc14ba59209c39e00856986deab2d..5fa6c85f06f863f5d18bc4939ffa0ae820d222bd 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -22,7 +22,7 @@ cc_library( hdrs = ["mark_for_compilation_pass_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -34,7 +34,7 @@ cc_library( hdrs = ["xla_device_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -46,7 +46,7 @@ cc_library( hdrs = ["build_xla_ops_pass_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -58,7 +58,7 @@ cc_library( hdrs = ["xla_ops_common_flags.h"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc index 73f4dc73ed83e2d1e89ccd6c99970d46b5767104..961c17c17eac891261530ef25baaa50f8496c331 100644 --- 
a/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.cc @@ -16,7 +16,7 @@ limitations under the License. #include // NOLINT #include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -34,7 +34,7 @@ void AllocateAndParseFlags() { Flag("tf_xla_enable_lazy_compilation", &flags->tf_xla_enable_lazy_compilation, ""), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); } } // namespace diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 7277a1d1f8ad5fa045645ead839ab9efa01e89c7..bad306e0b0a3061ba13dc69c08066c642667a2b9 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -19,7 +19,8 @@ limitations under the License. 
#include #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -64,7 +65,18 @@ static void AllocateFlags() { Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, "enable fusion of element-wise operations only using XLA when " "global_jit_level is ON*.")}); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Parsed MarkForCompilationPassFlags:"; + VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit; + VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size; + VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size; + VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug; + VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; + VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel; + VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only; + } } // Append to *append_to flag definitions associated with the XLA bridge's diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 2affda6ab4e0fbad32a246744fa5b38aeb629c1b..79b47357a179d2d9e0d1b6bf9c9f814288bcd5e1 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -33,7 +33,7 @@ void AppendMarkForCompilationPassFlags( // The values of flags associated with the XLA bridge's // mark_for_compilation_pass module. 
-typedef struct { +struct MarkForCompilationPassFlags { int32 tf_xla_auto_jit; // Control compilation of operators into XLA // computations on CPU and GPU devices. 0 = use // ConfigProto setting; -1 = off; 1 = on for things @@ -55,7 +55,7 @@ typedef struct { // is set to ON* and overrides its behavior. If // true, enable fusion of element-wise operations // only using XLA. -} MarkForCompilationPassFlags; +}; // Return a pointer to the MarkForCompilationPassFlags struct; // repeated calls return the same pointer. diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc index 1bb2fce2dbad5bffce2e33b665b7222090d0855a..76b80d3034c8a13a1ddf1afe548d5c3d9c7b2cec 100644 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -41,7 +41,7 @@ static void AllocateFlags() { "Switch a device into 'on-demand' mode, where instead of " "autoclustering ops are compiled one by one just-in-time."), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); } // Return a pointer to the XlaDeviceFlags struct; diff --git a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc index ae17fdffb9b6a574449b7f3155e050b029702db7..1443d48a734c0a44c1cd91d8d1218bdbed7f765c 100644 --- a/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include #include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" - +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -35,7 +35,13 @@ void AllocateAndParseFlags() { Flag("tf_xla_always_defer_compilation", &flags->tf_xla_always_defer_compilation, ""), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); + + if (VLOG_IS_ON(1)) { + VLOG(1) << "Parsed XlaOpsCommonFlags:"; + VLOG(1) << " tf_xla_always_defer_compilation = " + << flags->tf_xla_always_defer_compilation; + } } const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 11975a6bb07e03dc3d182beb3748eb2559de7e25..70033cae0afacb6a25598ee1abf2aeb2721e7496 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -61,8 +61,23 @@ struct OperationFilter { // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid // auto-clustering stateful RNG ops. bool allow_stateful_rng_ops; + + // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound + // to cluster ControlTrigger because of how we use deadness analysis. + bool allow_control_trigger; + + // Whether ops with dummy implementations are allowed. We avoid + // auto-clustering these ops so that the user is not surprised when XLA is + // implicitly enabled. If the user explicitly specifies to use XLA, it is fine + // to resort to a dummy implementation. Currently Assert and CheckNumerics ops + // have dummy XLA implementations. 
+ bool allow_dummy_ops; }; +bool IsDummyImplOp(absl::string_view op_name) { + return op_name == "Assert" || op_name == "CheckNumerics"; +} + bool IsStatefulRandomOp(absl::string_view op_name) { return op_name == "RandomUniform" || op_name == "RandomShuffle" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || @@ -225,6 +240,12 @@ bool IsCompilableCall(const NodeDef& call_def, IsStatefulRandomOp(node->type_string())) { return false; } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + return false; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, lib_runtime)) { @@ -452,7 +473,14 @@ Status FindCompilationCandidates( OperationFilter op_filter; op_filter.allow_resource_ops = registration->compile_resource_ops; - op_filter.allow_stateful_rng_ops = registration->requires_compilation; + op_filter.allow_stateful_rng_ops = + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); + op_filter.allow_control_trigger = + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); + op_filter.allow_dummy_ops = (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, @@ -467,6 +495,15 @@ Status FindCompilationCandidates( VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; continue; } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op"; + continue; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + VLOG(2) << "Rejecting " << node->name() << ": dummy op (" + << node->type_string() << ")"; + continue; + } if (!op_filter.allow_resource_ops && 
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { @@ -597,11 +634,14 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - // We can always *compile* resource operations and stateful RNGs, even if we - // are sometimes unable to auto-cluster them. + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. OperationFilter op_filter; op_filter.allow_resource_ops = true; op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_dummy_ops = true; + return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } @@ -613,10 +653,8 @@ Status MarkForCompilationPass::Run( GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); - bool cpu_global_jit = flags->tf_xla_cpu_global_jit; bool fusion_only = flags->tf_xla_fusion_only; - VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; @@ -635,9 +673,6 @@ Status MarkForCompilationPass::Run( return false; } - // If this device requires a JIT, we must say yes. - if (registration->requires_compilation) return true; - // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); @@ -674,18 +709,21 @@ Status MarkForCompilationPass::Run( return false; } - // Otherwise use the value of global_jit_level. 
- // Ignore enable_jit_by_default if global jit compilation for CPU - // is explicitly requested via tf_xla_cpu_global_jit flag - bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + // Otherwise use the value of global_jit_level and the device's + // autoclustering policy. bool should_compile = - (ignore_registration || registration->enable_jit_by_default) && - global_jit_level != OptimizerOptions::OFF; + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); if (!should_compile) { if (global_jit_level == OptimizerOptions::OFF) { VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; } else { - VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + VLOG(2) + << "Rejecting " << node->name() + << ": autoclustering for device only when requested explicitly."; } } return should_compile; @@ -1073,12 +1111,10 @@ Status MarkForCompilationPass::RunImpl( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if the operator is placed on a device that requires - // compilation, or if it contains at least one op that is marked for + // Also, always compile if it contains at least one op that is marked for // compilation that is not an Identity op. 
if (effective_cluster_sizes[cluster] >= min_cluster_size || - (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || - registration->requires_compilation) { + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) { string& name = cluster_names[cluster]; if (name.empty()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index ead1cf4fd5faff649e8518aaeb95935ccef4ca52..24d78c077268f83cebbdafddc1a658ae8dc6b8d8 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -817,14 +817,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unordered_map clusters = GetClusters(*graph); - ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; - - // ctrl_trigger_a has inputs with mismatching deadness so it won't be - // clustered. ctrl_trigger_b is okay to cluster. - std::unordered_map expected_clusters( - {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); - EXPECT_EQ(clusters, expected_clusters); + // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so + // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't + // cluster it because of b/118970344. 
+ EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, RandomShape) { @@ -923,9 +919,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); - EXPECT_NE(clusters["test/shape_rng"], ""); - EXPECT_NE(clusters["test/reshape"], ""); - EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); + EXPECT_EQ(clusters["test/shape_rng"], ""); + EXPECT_EQ(clusters["test/reshape"], ""); } TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { @@ -1088,7 +1083,7 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { EXPECT_NE(clusters["test/c"], ""); } -TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { +TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) { Scope root = Scope::NewRootScope().ExitOnError(); Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); @@ -1104,5 +1099,53 @@ TEST(XlaCompilationTest, DontAutoclusterStatefulRandomOp) { EXPECT_EQ(clusters["test/a"], ""); EXPECT_EQ(clusters["test/b"], ""); } + +TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) { + absl::string_view xla_cpu_device = + "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + 
n->set_assigned_device_name(string(xla_cpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/check"], ""); + EXPECT_NE(clusters["test/greaterequal"], ""); + EXPECT_NE(clusters["test/assert"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterDummyOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/assert"], ""); + EXPECT_EQ(clusters["test/check"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 5b9610322336acbcede0bef0538043b8ff917c16..36b345ecbff8d5f6ba3c241b9e164f677236c20d 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -133,6 +133,10 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { graph->RemoveEdge(out_edge_to_clone); } + if (n->out_edges().empty()) { + graph->RemoveNode(n); + } + return Status::OK(); } @@ -191,6 +195,10 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { } } + // Recompute post order since PartiallyDeclusterNode may have deleted nodes. 
+ post_order.clear(); + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); nodes_to_partially_decluster.clear(); TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); @@ -210,7 +218,8 @@ bool IsIntraClusterEdge(const Edge& edge) { bool IsMustCompileDevice(const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - return registration->requires_compilation; + return registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; } return false; diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 74d5ef57184197ad6e9e5048722e84863756a3f5..1fc5da5071f7aa6f6dd6636aacd60e33c12431a6 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -437,5 +437,32 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); } +TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) { + const char* const kClusteredProducer0Name = "ClusteredProducer0"; + const char* const kClusteredProducer1Name = "ClusteredProducer1"; + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName(kClusteredProducer0Name)); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName(kClusteredProducer1Name)); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, 
"cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h deleted file mode 100644 index 7c8c04152d2f3a0fd46711df24756b7e68b967ea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ -#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ - -#include -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -// A thread-safe, first-in-first-out queue. -template -class ProducerConsumerQueue { - public: - ProducerConsumerQueue() - : capacity_(std::numeric_limits::max()) {} - ~ProducerConsumerQueue() = default; - - // Wait until the queue is non-full, then append a copy of v. 
- void Put(const T &v); - - // Wait until the queue is non-empty, then remove and return the head value. - T Get(); - - // If the queue is non-empty, remove the head value, placing it in *pv, and - // return true; otherwise return false. - bool TryGet(T *pv); - - // Set the capacity of the queue; the queue is full whenever count() >= - // capacity(). The initial value is the maximum size_t. Requires size > 0. - void set_capacity(std::size_t size); - - // Return the capacity of the queue. - std::size_t capacity() const; - - // Return the number of elements in the queue. - std::size_t count() const; - - // Implementation details follow. Clients should ignore. - private: - mutable tensorflow::mutex mu_; // protects all fields below - tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); - tensorflow::condition_variable non_full_ GUARDED_BY(mu_); - std::size_t capacity_ GUARDED_BY(mu_); - std::deque queue_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); -}; - -// ------------------------------------------------------ -// Implementation details follow. Clients should ignore. - -// Wait until the queue is non-full, then append a copy of v. -template -void ProducerConsumerQueue::Put(const T &v) { - mutex_lock lock(mu_); - while (queue_.size() >= capacity_) { - non_full_.wait(lock); - } - queue_.push_back(v); - non_empty_.notify_one(); -} - -// Wait until the queue is non-empty, then remove and return the head value. -template -T ProducerConsumerQueue::Get() { - mutex_lock lock(mu_); - while (queue_.empty()) { - non_empty_.wait(lock); - } - non_full_.notify_one(); - T result_value = queue_.front(); - queue_.pop_front(); - return result_value; -} - -// If the queue is non-empty, remove the head value, placing it in *pv, and -// return true; otherwise return false. 
-template -bool ProducerConsumerQueue::TryGet(T *pv) { - mutex_lock lock(mu_); - bool got_element = !queue_.empty(); - if (got_element) { - non_full_.notify_one(); - *pv = queue_.front(); - queue_.pop_front(); - } - return got_element; -} - -// Set the capacity of the queue; the queue is full whenever count() >= -// capacity(). The initial value is the maximum size_t. Requires size > 0. -template -void ProducerConsumerQueue::set_capacity(std::size_t size) { - mutex_lock lock(mu_); - CHECK_NE(size, 0); - capacity_ = size; - non_full_.notify_all(); -} - -// Return the capacity of the queue. -template -std::size_t ProducerConsumerQueue::capacity() const { - mutex_lock lock(mu_); - std::size_t max_elements = capacity_; - return max_elements; -} - -// Return the number of elements in the queue. -template -std::size_t ProducerConsumerQueue::count() const { - mutex_lock lock(mu_); - std::size_t num_elements = queue_.size(); - return num_elements; -} -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc deleted file mode 100644 index f61260c6e52756ee039829afdc7452f5f760c221..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/producer_consumer_queue.h" - -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -typedef ProducerConsumerQueue IntQueue; - -// Insert integers between low inclusive and high exclusive into q. -void PushRange(IntQueue *q, int low, int high) { - while (low != high) { - q->Put(low); - VLOG(2) << "Pushing " << low; - ++low; - } -} - -// Push the numbers between 0 and 999 inclusive from several threads in the -// pool. -void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { - VLOG(1) << "Adding 20-36"; - pool->Schedule([queue] { PushRange(queue, 20, 36); }); - VLOG(1) << "Adding 7-20"; - pool->Schedule([queue] { PushRange(queue, 7, 20); }); - VLOG(1) << "Adding 36-501"; - pool->Schedule([queue] { PushRange(queue, 36, 501); }); - VLOG(1) << "Adding 501-1000"; - pool->Schedule([queue] { PushRange(queue, 501, 1000); }); - VLOG(1) << "Adding 0-5"; - pool->Schedule([queue] { PushRange(queue, 0, 5); }); - VLOG(1) << "Adding 5-7"; - pool->Schedule([queue] { PushRange(queue, 5, 7); }); -} - -// Pop elements from queue using Get(). Make sure that exactly elements -// were present and their values are all integers between 0 and high-1 -// inclusive. -void GetRange(IntQueue *queue, int high) { - VLOG(1) << "Testing Wait"; - std::vector results; - for (int i = 0; i != high; ++i) { - int r = queue->Get(); - VLOG(2) << "Waited and got " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK(results[i] == i); - } -} - -// Pop elements from queue using TryGet(). Make sure that exactly -// elements were present and their values are all integers between 0 and high-1 -// inclusive. 
-void TryGetRange(IntQueue *queue, int high) { - std::vector results; - // Give up if we don't get all the elements back from the queue - // in 10 seconds. - int timeout = 10; - int r; - for (int i = 0; i != high; ++i) { - while (!queue->TryGet(&r)) { - if (!timeout--) { - LOG(FATAL) << "Can't find all elements in the queue"; - } - VLOG(1) << "Sleeping for a second..."; - sleep(1); - } - VLOG(2) << "Popped " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - CHECK(!queue->TryGet(&r)); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK_EQ(i, results[i]); - } -} - -const int kNumThreads = 15; - -TEST(ProducerConsumerQueue, GetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - GetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, TryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - TryGetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, ParallelGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { GetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -TEST(ProducerConsumerQueue, ParallelTryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 4a5ea9e0a5f8cf79478069931da598099ae4e716..3df5479a55e841380ca7b8cdd0add9fd17487091 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -65,14 +66,14 @@ string XlaCompilationCache::DebugString() { // Compute a string signature which encodes the shapes of the // arguments in the supplied list. -string XlaCompilationCache::SignatureDebugString(const Signature& sig) { - string result = sig.name; - for (const auto& a : sig.arg_types) { +string XlaCompilationCache::Signature::HumanString() const { + string result = name; + for (const auto& a : arg_types) { absl::StrAppend(&result, ",", DataTypeString(a.first), a.second.DebugString()); } - for (const auto& v : sig.arg_values) { + for (const auto& v : arg_values) { absl::StrAppend(&result, "; ", v.DebugString()); } return result; @@ -84,7 +85,9 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { - if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { + if (arg_values[i].dtype() != other.arg_values[i].dtype() || + arg_values[i].shape() != other.arg_values[i].shape() || + arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { return false; } } @@ -108,96 +111,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( return h; } -Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - Signature* signature) { - signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.reserve(constant_args.size()); - - signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); - - for (int i = 0; i < ctx->num_inputs(); ++i) { - if (constant_args.count(i) > 0) { - // Use the values of compile time constants in the signature. 
- signature->arg_values.push_back(constant_args.at(i)); - } else if (variable_args.count(i) > 0) { - const OptionalTensor& variable = variable_args.at(i); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); - } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); - } - } else { - signature->arg_types.emplace_back(ctx->input_dtype(i), - ctx->input(i).shape()); - } - } - return Status::OK(); -} - -namespace { - -// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. -Status BuildArguments(const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - std::vector* args) { - args->resize(ctx->num_inputs()); - - for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { - XlaCompiler::Argument& arg = (*args)[input_num]; - if (constant_args.count(input_num) > 0) { - // Handles compile-time constants. - const Tensor& input = constant_args.at(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - } else if (variable_args.count(input_num) == 0) { - // Handles the non-constant arguments. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - } else { - // Handles resource variables. 
- const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - const OptionalTensor& variable = variable_args.at(input_num); - arg.name = variable.name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable.present) { - const Tensor& value = variable.value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; - } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do - // permit uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); - } +xla::StatusOr +XlaCompilationCache::BuildSignature( + const NameAttrList& function, + absl::Span args) { + Signature signature; + signature.name = Canonicalize(function.name(), AttrSlice(&function.attr())); + for (const XlaCompiler::Argument& arg : args) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + signature.arg_values.push_back(arg.constant_value); + break; + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kResource: + signature.arg_types.emplace_back(arg.type, arg.shape); + break; + default: + return errors::InvalidArgument( + "Unhandled argument kind in XlaCompilationCache: ", + arg.HumanString()); } } - - return Status::OK(); + return std::move(signature); } -} // namespace - Status XlaCompilationCache::BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, @@ -227,25 +164,38 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, 
const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { - // Set the compile threshold to 1 to implement CompileMode::kStrict. - int64 compile_threshold = - compile_mode == CompileMode::kLazy ? kDefaultCompilationThreshold : 1; - return CompileImpl(options, function, constant_args, variable_args, ctx, - compile_options, /*compile_single_op=*/false, + absl::optional compile_threshold; + if (compile_mode == CompileMode::kLazy) { + compile_threshold = kDefaultCompilationThreshold; + } + auto compile_fn = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + return compiler->CompileFunction(compile_options, function, args, result); + }; + return CompileImpl(options, function, args, compile_fn, /*compile_threshold=*/compile_threshold, out_compilation_result, out_executable); } +static bool IsMegamorphic(int64 compile_count, int64 execution_count) { + const int64 kCompileThreshold = 10; + const int64 kMinExecutionsPerCompile = 50; + + // This heuristic is trying to capture the following property: have we sunk a + // certain minimum amount of compile time into the cluster that didn't quite + // "pay off"? 
+ return compile_count > kCompileThreshold && + execution_count < kMinExecutionsPerCompile * compile_count; +} + Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { @@ -253,54 +203,41 @@ Status XlaCompilationCache::CompileSingleOp( NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); - return CompileImpl(options, name, constant_args, variable_args, ctx, - compile_options, - /*compile_single_op=*/true, /*compile_threshold=*/1, + auto compile_op = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + std::vector result_dtypes(ctx->num_outputs()); + for (int i = 0; i < result_dtypes.size(); ++i) { + result_dtypes[i] = ctx->expected_output_dtype(i); + } + return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(), + args, result_dtypes, result); + }; + return CompileImpl(options, name, args, compile_op, + /*compile_threshold=*/absl::nullopt, out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, - int64 compile_threshold, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { - VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << constant_args.size() - << " 
num_variable_args=" << variable_args.size(); - for (int i = 0; i < ctx->num_inputs(); i++) { - TensorShape shape = ctx->input(i).shape(); - VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i)) - << " present=" << ctx->has_input(i) - << " shape=" << shape.DebugString(); - } - for (auto& iterator : variable_args) { - const OptionalTensor& variable = iterator.second; - VLOG(2) << "variable present=" << variable.present - << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString() - << " TF arg= " << iterator.first; - } - VLOG(2) << "num_outputs = " << ctx->num_outputs(); - for (int i = 0; i < ctx->num_outputs(); i++) { - VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i); + VLOG(2) << "num_inputs=" << args.size(); + for (int i = 0; i < args.size(); i++) { + VLOG(2) << i << ": " << args[i].HumanString(); } } - TF_RET_CHECK(constant_args.size() + variable_args.size() <= - ctx->num_inputs()); - - Signature signature; - TF_RETURN_IF_ERROR( - BuildSignature(function, constant_args, variable_args, ctx, &signature)); + TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); + VLOG(2) << "Signature: " << signature.HumanString(); - VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not // protect the contents of the cache entry. Entry* entry; @@ -319,25 +256,67 @@ Status XlaCompilationCache::CompileImpl( // (since they get the benefit of XLA right away without waiting for warmup) // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at // most one cluster-compilation's worth of compile time). - bool is_first_execution = [&] { + bool is_first_execution; + + // We avoid compiling clusters that have "gone megamorphic" i.e. have an + // excessive amount of shape dynamism. 
+ bool is_megamorphic; + + { mutex_lock lock(cluster_compile_stats_mu_); auto it = cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) .first; - return it->second.execution_count++ == 0; - }(); + is_first_execution = it->second.execution_count++ == 0; + + // The is_megamorphic bit is "sticky". We assume clusters that have been + // observed to be megamorphic once stay megamorphic forever. + it->second.is_megamorphic |= + IsMegamorphic(/*compile_count=*/it->second.compile_count, + /*execution_count=*/it->second.execution_count); + is_megamorphic = it->second.is_megamorphic; + } // Acquire the cache entry lock and compile, if necessary. // TODO(phawkins): this locking will need to be restructured when we implement // cache eviction. mutex_lock entry_lock(entry->mu); int64 current_request_count = ++entry->request_count; + VLOG(2) << "Compilation cache entry hit: " << entry->compiled + << " signature: " << signature.HumanString() << " with request count " + << current_request_count << " and compile threshold " + << compile_threshold.value_or(0); if (!entry->compiled) { - VLOG(2) << "Compilation cache miss for signature: " - << SignatureDebugString(signature) << " with request count " - << current_request_count << " and compile threshold " - << compile_threshold; - if (!is_first_execution && current_request_count < compile_threshold) { + const bool should_compile = [&] { + if (!compile_threshold.has_value()) { + // Lazy compilation is disabled. 
+ return true; + } + + if (is_megamorphic) { + VLOG(3) << "Not compiling cluster " << function.name() + << " because it is megamorphic."; + return false; + } + + if (is_first_execution) { + return true; + } + + bool reached_compile_threshold = + current_request_count >= *compile_threshold; + if (!reached_compile_threshold) { + VLOG(3) + << "Not compiling cluster " << function.name() + << " because it has not reached compile threshold; threshold is " + << *compile_threshold << " execution count " + << current_request_count << "."; + } + return reached_compile_threshold; + }(); + + if (!should_compile) { + VLOG(2) << "Not compiling for signature: " << signature.HumanString(); *out_compilation_result = nullptr; *out_executable = nullptr; return Status::OK(); @@ -347,21 +326,12 @@ Status XlaCompilationCache::CompileImpl( const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) - std::vector args; - TF_RETURN_IF_ERROR( - BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - if (compile_single_op) { - entry->compilation_status = - compiler.CompileSingleOp(compile_options, signature.name, ctx, args, - &entry->compilation_result); - } else { - entry->compilation_status = compiler.CompileFunction( - compile_options, function, args, &entry->compilation_result); - } + entry->compilation_status = + compile_fn(&compiler, &entry->compilation_result); TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); entry->compilation_status = diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index b43e5d40e6402d24b80f7c689018d81e8a5d7f09..846d0c963dbfdf55f51120f2f138d12f5f63839b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -17,9 +17,12 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/graph.pb.h" @@ -30,13 +33,6 @@ limitations under the License. namespace tensorflow { -// Struct that represents a possibly-absent Tensor. -struct OptionalTensor { - string name; // A descriptive name - bool present = false; // Is the tensor present? - Tensor value; // If present, what is the Tensor's value? -}; - // The XlaCompilationCache class caches the results of the XlaCompiler class, // which converts a Tensorflow graph into a compiled XLA compilation. // @@ -58,11 +54,7 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `constant_args` is a map of tensorflow argument number to its constant - // value. - // `variable_args` is a snapshot of the current values of the - // resource variable arguments to `function`; uninitialized variables are - // represented by an absent OptionalTensor. + // `args` is a description of the arguments to the computation. // // `compile_mode` controls the behavior of the compilation cache on a cache // miss. If `compile_mode` is `kLazy` then, based on some profitability @@ -78,9 +70,7 @@ class XlaCompilationCache : public ResourceBase { // outputs. 
Status Compile(const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, @@ -90,8 +80,7 @@ class XlaCompilationCache : public ResourceBase { // XlaCompiler::CompileFunction. Status CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -101,26 +90,6 @@ class XlaCompilationCache : public ResourceBase { string DebugString() override; - private: - // Common implementation of Compile and CompileSingleOp. - Status CompileImpl( - const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op, int64 compile_threshold, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - // Takes `result` which has been compiled from a Tensorflow subgraph to a - // XLA computation already, and generates an XLA LocalExecutable `executable`. - Status BuildExecutable(const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - std::unique_ptr* executable); - - xla::LocalClient* const client_; - const DeviceType device_type_; - // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. 
struct Signature { @@ -137,14 +106,35 @@ class XlaCompilationCache : public ResourceBase { struct Hash { uint64 operator()(const Signature& signature) const; }; + + // Returns a human-readable description of the signature. + string HumanString() const; }; - static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, Signature* signature); + static xla::StatusOr BuildSignature( + const NameAttrList& function, + absl::Span args); + + private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); + + // Takes `result` which has been compiled from a Tensorflow subgraph to a + // XLA computation already, and generates an XLA LocalExecutable `executable`. + Status BuildExecutable(const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + std::unique_ptr* executable); + + xla::LocalClient* const client_; + const DeviceType device_type_; // The value associated with a cache entry. struct Entry { @@ -180,7 +170,13 @@ class XlaCompilationCache : public ResourceBase { // Cumulative time spent compiling the cluster. int64 cumulative_compile_time_us = 0; + + // True if we have decided that this cluster is too dynamic (i.e. its shapes + // change too frequently) to profitably JIT compile. Once a cluster is + // tagged megamorphic, it stays megamorphic forever. + bool is_megamorphic = false; }; + mutex cluster_compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. 
diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..018c7c219f445bdca17f4f8b060e3678fe1be9ee --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(XlaCompilationCacheTest, SignatureEquality) { + NameAttrList fn; + fn.set_name("afunction"); + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_INT32; + args[0].shape = TensorShape({4, 0}); + args[0].constant_value = Tensor(DT_INT32, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].type = DT_FLOAT; + args[0].constant_value = Tensor(DT_FLOAT, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].shape = TensorShape({0, 4}); + args[0].constant_value = Tensor(DT_FLOAT, {0, 4}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3, + XlaCompilationCache::BuildSignature(fn, 
args)); + + std::vector signatures = {s1, s2, s3}; + for (int i = 0; i < signatures.size(); ++i) { + for (int j = 0; j < signatures.size(); ++j) { + EXPECT_EQ(i == j, signatures[i] == signatures[j]) + << signatures[i].HumanString() << " " << signatures[j].HumanString(); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 31cb32e3059bc17e3cde36e5c9f90cc78a39e473..1fe612d43d10030675cf307b109e4dcc89cb2d79 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -187,8 +187,13 @@ Status XlaCompileOnDemandOp::Compile( compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); - return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - compile_options, result, executable); + + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arguments, variable_args, ctx, &args)); + + return cache->CompileSingleOp(options, args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index cbfeb38805038825917c16684b9c441818972042..116e0756036e722c13f27579aa0e0876d2e846a7 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -42,8 +42,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = !compile_on_demand; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + compile_on_demand + ? 
XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested + : XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 2289abd2df372620c05db900bd46d1cdf6174377..5c1b55cb57f58387086ab9eaf924d0beffb43e18 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -446,7 +446,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, // Any op assigned to the device that isn't rewritten by the graph rewriter // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. - kernel_factory::OpKernelRegistrar::Factory factory = + OpKernel* (*factory)(OpKernelConstruction*) = [](OpKernelConstruction* context) -> OpKernel* { return new XlaCompileOnDemandOp(context); }; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 8881b697bc863e58006361924f7761c2e5bba493..49f53b477ef5508a23812453cb61e29a8d8b9379 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -112,6 +112,12 @@ class XlaDevice : public LocalDevice { // compute, host-to-device, and device-to-host communication. bool use_multiple_streams = false; + // A function that describes how the on-host shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared + // shapes for computations. Must be non-null. 
XlaCompiler::ShapeRepresentationFn shape_representation_fn; // If padded_shape_fn is empty, a default implementation that returns diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index eb3cf27624bb76058c8f0cf2e999818434d38d9e..6e6532731e64bd42ee56aa719748988f321e0f17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -70,9 +70,12 @@ XlaDeviceContext::XlaDeviceContext( CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = - [](const TensorShape& shape, - DataType dtype) -> xla::StatusOr { return shape; }; + shape_representation_fn_ = [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -99,7 +102,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, CHECK(xla_tensor); Status status = [&]() -> Status { - TF_ASSIGN_OR_RETURN(TensorShape shape, + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn_(device_tensor->shape(), device_tensor->dtype())); @@ -111,9 +114,15 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal())); + // The cpu_tensor and literal that we created here hold the data of host + // tensor in descending layout. The layout could be different from layout in + // device_tensor (but the logical shape has to be the same). The + // transfer_manager is responsible to do corresponding transposing when + // transferring the data to device. 
xla::BorrowingLiteral literal( static_cast(DMAHelper::base(cpu_tensor)), - xla_tensor->shaped_buffer().on_host_shape()); + xla::ShapeUtil::MakeShape(shape.element_type(), + xla::AsInt64Slice(shape.dimensions()))); VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " << xla_tensor->shaped_buffer().ToString(); @@ -183,8 +192,15 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); + // Transfer manager requires the shape of the shaped buffer to be the same as + // literal shape except for the layout. Set the literal to use xla_tensor's + // shape as it is derived from the cpu_tensor's shape using + // shape_representation_fn_. xla::MutableBorrowingLiteral literal; - TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(cpu_tensor, &literal)); + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout( + xla_tensor->shaped_buffer().on_host_shape()), + cpu_tensor, &literal)); TensorReference ref(*device_tensor); transfer_manager_->TransferLiteralFromDevice( diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 241ea8f60df8b66a9a39e3e176ecd4119f27d780..adf0f994b84d9fbf918a5b2478aa7d106853e038 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/kernels/stack.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel { .Device(DEVICE) \ .TypeConstraint("T") \ .HostMemory("input"), \ - RetvalOp); + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("StackV2") \ + .Device(DEVICE) \ + .HostMemory("max_size") \ + .HostMemory("handle"), \ + StackOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("T", TYPES), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("elem_type", TYPES), \ + StackPopOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp); -// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// TODO(b/118881356): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // and write the tensors they access in order to concatenate them into a batch. 
// We would need either to call out to an XLA computation to perform the diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 8f28b38b5e15052e9a14bd1ecf1b3047085d98f1..441970169581d53e0d8683b98d26712445b170ea 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -37,8 +37,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, std::vector* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); @@ -53,24 +53,25 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, return Status::OK(); } - XlaDevice::Options options; - options.platform = platform.ValueOrDie(); - options.device_name_prefix = name_prefix; - options.device_name = DEVICE_XLA_GPU; - options.device_ordinal = 0; - options.compilation_device_name = DEVICE_GPU_XLA_JIT; - options.use_multiple_streams = false; - auto device = absl::make_unique(session_options, options); - - // TODO(b/78468222): Uncomment after fixing this bug - // status = device->UseGpuDeviceInfo(); - // if (!status.ok()) { - // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, - // " device"); - // return status; - // } - - devices->push_back(device.release()); + for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) { + XlaDevice::Options options; + options.platform = platform.ValueOrDie(); + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_XLA_GPU; + options.device_ordinal = i; + options.compilation_device_name = DEVICE_GPU_XLA_JIT; + options.use_multiple_streams = true; 
+ auto device = absl::make_unique(session_options, options); + + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, + " device number ", i); + return status; + } + + devices->push_back(device.release()); + } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index dc37362fd8611577317f62083796fd4d655e7066..e828bae865d630bd40f227943cdabb2d8d95ca48 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -45,8 +45,8 @@ Status XlaInterpreterDeviceFactory::CreateDevices( XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, registration); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6e51bfca4a1504f8f11fe60159cb44b2ae19fa1b..3b0bda4caa161a7561a3098b89420329998ff8a7 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -191,40 +191,6 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { return Status::OK(); } -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. 
-ScopedShapedBuffer ExtractSubShapedBuffer( - ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator) { - const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_host_shape(), index); - const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_device_shape(), index); - - ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, - shaped_buffer->platform(), - shaped_buffer->device_ordinal()); - - auto& shape_tree = shaped_buffer->buffers(); - auto& sub_shape_tree = sub_shaped_buffer.buffers(); - sub_shape_tree.CopySubtreeFrom(shape_tree, - /*source_base_index=*/{index}, - /*target_base_index=*/{}); - shape_tree.ForEachMutableElement( - [index](const xla::ShapeIndex& shape_index, - tensorflow::se::DeviceMemoryBase* data) { - // shape_index is empty for the root node. Ignore that. - if (!shape_index.empty() && shape_index[0] == index) { - *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); - } - }); - return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); -} -} // namespace internal -using internal::ExtractSubShapedBuffer; - XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams) @@ -391,8 +357,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); if (use_multiple_streams_) { xla_tensor->ResetDefinitionEvent(definition_event, stream); } @@ -445,7 +410,6 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = 
ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - se::DeviceMemoryBase buffer = output.buffer({output_num}); if (variable_infos[i].var()->tensor()->dtype() != write.type) { return errors::Internal("Mismatched type in variable write"); @@ -455,18 +419,20 @@ Status XlaComputationLaunchContext::PopulateOutputs( Tensor output_tensor; TF_RETURN_IF_ERROR( ctx->allocate_temp(write.type, write.shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); - if (use_multiple_streams_) { - xla_tensor->ResetDefinitionEvent(definition_event, stream); + if (write.shape.num_elements() > 0) { + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); + if (use_multiple_streams_) { + xla_tensor->ResetDefinitionEvent(definition_event, stream); + } } *variable_infos[i].var()->tensor() = output_tensor; } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable_infos[i].var()->tensor() = output_tensor; } ++output_num; @@ -474,4 +440,60 @@ Status XlaComputationLaunchContext::PopulateOutputs( return Status::OK(); } +Status XlaComputationLaunchContext::BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args) { + args->resize(ctx->num_inputs()); + + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { + XlaCompiler::Argument& arg = (*args)[input_num]; + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. 
+ const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); + arg.constant_value = input; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); + } else { + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 81e205d13f711a701026b82100c17423595919ed..437db019a0eabe66417725148d8b121842e90479 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -35,6 +35,13 @@ limitations under the License. namespace tensorflow { class XlaAllocator; +// Struct that represents a possibly-absent Tensor. 
+struct OptionalTensor { + string name; // A descriptive name + bool present = false; // Is the tensor present? + Tensor value; // If present, what is the Tensor's value? +}; + // Takes a snapshot of the values of resource variable arguments, whose indices // are specified in `variable_indices` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is @@ -139,6 +146,13 @@ class XlaComputationLaunchContext { bool allocate_xla_tensors, bool use_multiple_streams); + // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch + // op. + static Status BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args); + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. // @@ -223,17 +237,6 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; -// Exposed in this header file for microbenchmarking purposes, but this is an -// internal implementation detail. -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. -xla::ScopedShapedBuffer ExtractSubShapedBuffer( - xla::ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator); -} // namespace internal - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc deleted file mode 100644 index a45932403ec1760d6b985d5357fd6d84fbf257a2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Contains microbenchmarks for performance critical functions in -// xla_launch_util.cc. - -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs -// (cardinality of each non-leaf node's children). -void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { - tensorflow::testing::StopTiming(); - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); - for (int i = 0; i < depth; ++i) { - std::vector shapes(fan_out, shape); - shape = xla::ShapeUtil::MakeTupleShape(shapes); - } - xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, - /*device_ordinal=*/0); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - // Extract a buffer from approximately the middle of the first level of the - // tree. 
- (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, - /*index=*/fan_out / 2, - /*allocator=*/nullptr) - .release(); - } -} - -BENCHMARK(BM_ExtractSubBuffer) - ->ArgPair(1, 4) - ->ArgPair(1, 8) - ->ArgPair(1, 32) - ->ArgPair(1, 64) - ->ArgPair(1, 128) - ->ArgPair(1, 256) - ->ArgPair(1, 512) - ->ArgPair(2, 4) - ->ArgPair(2, 8) - ->ArgPair(2, 32) - ->ArgPair(2, 64) - ->ArgPair(2, 128); - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - tensorflow::testing::RunBenchmarks(); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 6f8b198262dfb46b3fd76c52b5c005778cb906eb..d1f7f754c8338487557eda512c56be34c9e958b7 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -43,11 +43,10 @@ namespace tensorflow { } } -Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, +Status XlaTensor::AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal) { - xla::Shape on_host_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); xla::Shape on_device_shape = client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6d7a6fd66c80f6b8c29ad7adb4c9ae8505f5ed81..77e80aa2527ecc2221ac61f7b7e6ebcce0982931 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -50,7 +50,7 @@ class XlaTensor { // Assign the internal ShapedBuffer to new memory for the given dtype and // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it // is replaced and the managed memory deallocated. 
- Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal); // Some Tensors can have complex on-device shapes, including tuple shapes. To diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6945de1eda1693d7bb870b3310bd7366b364aaaa..6b8e6bba1e1bbfd773141d33721e4d7e30420a11 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -470,12 +470,12 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", + "//tensorflow/python/ops/signal", ], ) @@ -837,8 +837,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - # Stack ops are not implemented in the on-demand compilation model yet. 
- disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 4e6dd6abfc9cdbabbbcdf0734be828f0aa28683b..332381c59eed06d5697e58efb1d8fa2b6ef604d2 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + import numpy as np from tensorflow.compiler.tests import xla_test @@ -967,7 +969,7 @@ class BinaryOpsTest(xla_test.XLATestCase): self._testBinary( array_ops.expand_dims, np.array([42], dtype=dtype), - np.int32(0), + np.array([0], dtype=np.int64), expected=np.array([[42]], dtype=dtype)) self._testBinary( array_ops.expand_dims, @@ -994,15 +996,21 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[[1, 2], [3, 4]]], dtype=dtype), np.int32(3), expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.array([2], dtype=np.int64), + expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) def testPad(self): - for dtype in self.numeric_types: + for dtype, pad_type in itertools.product( + self.numeric_types, [np.int32, np.int64]): self._testBinary( array_ops.pad, np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[1, 2], [2, 1]], dtype=np.int32), + [[1, 2], [2, 1]], dtype=pad_type), expected=np.array( [[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0], @@ -1016,7 +1024,7 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[0, 3], [2, 1]], dtype=np.int32), + [[0, 3], [2, 1]], dtype=pad_type), expected=np.array( [[7, 7, 1, 2, 3, 7], [7, 7, 4, 5, 6, 7], diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index 
b3e13fbaa6b33bdaa1be123be558059e96de282e..e92afd5d6feb42ece233ee521e3a796c4bc3914a 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -24,10 +24,10 @@ import numpy as np import scipy.signal as sps from tensorflow.compiler.tests import xla_test -from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import signal from tensorflow.python.ops import spectral_ops from tensorflow.python.platform import googletest diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 561715ee1c3e0db37169cfd3fb431c0872987d75..6f51ae33a1b0fc8670ddf0cacb03a3b5a9176a91 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -593,6 +593,67 @@ class LazyCompilationTest(test.TestCase): self.assertFalse( InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) + def testIsMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Make the cluster go megamorphic by running it with lots of shape + # signatures where the cluster is executed with each signature only a few + # times. Then check that we don't compile the cluster ever again. + + for shape in range(10, 50): + for _ in range(0, 49): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 50): + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] 
* 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + + def testIsNotMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Run the cluster with lots of shape signatures, but in a way that it + # isn't megamorphic (i.e. each shape signature sees a lot of executions). + # Then check that the cluster has not been marked as megamorphic. + + for shape in range(10, 50): + for _ in range(0, 1000): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 10): + sess.run(y, feed_dict={x: [0.] * 60}) + + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] * 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + if __name__ == "__main__": os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index cfccf5f3d2a0a3f2910b2ac1c2747381b172a685..a6b58020126a3297944f199e99b0801387615564 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2466,20 +2466,21 @@ TEST_F(OpTest, Pack) { }); } -// TODO(b/31741898): crashes on GPU. TEST_F(OpTest, Pad) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(); - // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. 
- // DataType tpaddings = Choose({DT_INT32, DT_INT64}); - DataType tpaddings = DT_INT32; + DataType tpaddings = Choose({DT_INT32, DT_INT64}); std::vector paddings_vec; - std::uniform_int_distribution distribution(0, 7); for (int i = 0; i < t_dims.size(); ++i) { - paddings_vec.push_back(distribution(generator())); - paddings_vec.push_back(distribution(generator())); + std::uniform_int_distribution pad_distribution(0, t_dims[i]); + int pad_size = pad_distribution(generator()); + std::uniform_int_distribution lower_distribution(0, pad_size); + int low_pad_size = lower_distribution(generator()); + paddings_vec.push_back(low_pad_size); + paddings_vec.push_back(pad_size - low_pad_size); + t_dims[i] -= pad_size; } Tensor paddings; CHECK( diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/compiler/tests/resampler_ops_test.py index d05554fdb681a7b60f97a76b8fa33e4dfe6808d8..f87ac3360c905d7956ab3716c47d42765949774d 100644 --- a/tensorflow/compiler/tests/resampler_ops_test.py +++ b/tensorflow/compiler/tests/resampler_ops_test.py @@ -37,7 +37,7 @@ class ResamplerOpsTest(xla_test.XLATestCase): out = sess.run(resampled, {input_image: image_np, warp: warp_np}) self.assertAllCloseAccordingToType( - expected, out, half_rtol=1e-2, bfloat16_rtol=3e-2) + expected, out, rtol=5e-3, half_rtol=1e-2, bfloat16_rtol=3e-2) def _assertBackwardOpMatchesExpected(self, input_np, warp_np, grad_output_np, expected_grad_data, expected_grad_warp): diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dd2c252d383bca9c59033ac07e442b487e4975a6..77cdeac8168aa71555955b141852587d62ab59d3 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -40,6 +40,19 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" + def testWriteEmptyShape(self): + 
# Verifies that we can pass an uninitialized variable with an empty shape, + # assign it a value, and successfully return it. + for dtype in self.numeric_types: + with self.test_session() as sess, self.test_scope(): + zeros = np.zeros([3, 0], dtype=dtype) + v = resource_variable_ops.ResourceVariable(zeros) + p = array_ops.placeholder(dtype) + x = v.assign(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(zeros, sess.run(y, {p: zeros})) + def testOneWriteOneOutput(self): # Regression test for a bug where computations with one non-constant # output and one variable update were mishandled. diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5fc9a352ff930c7d281ec5c52168580e453c04b0..e0171415492658a76b25167107e01300ee4bde88 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -166,6 +166,7 @@ cc_library( "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", + "xla_expression.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", @@ -180,6 +181,7 @@ cc_library( "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", + "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", @@ -194,6 +196,7 @@ cc_library( ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -217,6 +220,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -362,8 +366,12 @@ tf_cc_test( tf_cc_test( name = "xla_compiler_test", - srcs = ["xla_compiler_test.cc"], + srcs = [ + "xla_compiler_test.cc", + "xla_expression_test.cc", + ], deps = [ + ":common", 
":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", @@ -386,6 +394,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -435,7 +444,7 @@ cc_library( "dump_graph.h", ], deps = [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc index a6c908ba011afb90fabacc855df8c6afbb35d254..2eb1f8cd849b67922f94cfe3f88456b0d6beeaf8 100644 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.cc +++ b/tensorflow/compiler/tf2xla/dump_graph_flags.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" @@ -41,7 +41,7 @@ static void AllocateFlags() { "Path prefix to which graphs dumped during debugging should be " "written."), }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); + xla::ParseFlagsFromEnv(*flag_list); } // Append to *append_to flag definitions associated with the XLA bridge's diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f818d80022da0bad851c896f2714c15b20b22195..9ef9f49f422ec4dfaf538ac3c0754ba3609d3f88 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -242,23 +242,20 @@ Status FunctionalizeControlFlowPass::Run( continue; } const string func_attr = it->second; - if 
(kNodeTypeToFunctionAttrMapping->find(n->type_string()) != - kNodeTypeToFunctionAttrMapping->end()) { - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_f15n_")); - bool modified; - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); - if (modified) { - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + bool modified; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); } } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 706ed4f5bbfac60de4653cc8c326214cd4d8d886..efb75749722893100494e089c0beb96944e9f1d4 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -23,9 +23,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -51,12 +52,11 @@ namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { - auto builder = ctx->builder(); auto client = ctx->compiler()->client(); - std::vector compile_time_constant_flags(expressions.size()); + std::vector arg_must_be_compile_time_constant(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant, /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); @@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); - if (arg.type == DT_RESOURCE) { - return errors::InvalidArgument( - "Resource as function argument is not yet implemented."); - } else if (expressions[i]->has_constant_value()) { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = expressions[i]->constant_value(); - } else if (compile_time_constant_flags[i]) { - arg.kind = 
XlaCompiler::Argument::kConstant; - TF_RET_CHECK(expressions[i]->resource() == nullptr) - << "Input with resource is not yet implemented."; - TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( - expressions[i]->handle())); - TF_ASSIGN_OR_RETURN(auto literal, - client->ComputeConstant(constant_graph)); - TF_RETURN_IF_ERROR( - LiteralToHostTensor(literal, arg.type, &arg.constant_value)); - } else { - arg.kind = XlaCompiler::Argument::kParameter; + switch (expressions[i]->kind()) { + case XlaExpression::Kind::kConstant: + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = expressions[i]->constant_value(); + break; + case XlaExpression::Kind::kXlaOp: + if (arg_must_be_compile_time_constant[i]) { + TF_ASSIGN_OR_RETURN(absl::optional value, + expressions[i]->ResolveConstant(client)); + if (!value.has_value()) { + return errors::InvalidArgument( + "Argument to function must be a compile-time constant, but " + "unable to resolve argument value to a constant."); + } + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *value; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + break; + case XlaExpression::Kind::kResource: + return errors::Unimplemented( + "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument("Invalid function argument"); } } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 276d744c096f8996c774964204feaa3762bdb844..2db2514397deca39e6874cf994532a20d2186316 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel { } const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; - if (arg.resource() != nullptr) { - ctx->SetResourceOutput(0, arg.resource()); - } else if (arg.has_constant_value()) { - ctx->SetConstantOutput(0, arg.constant_value()); - } else { - ctx->SetOutput(0, arg.handle()); - } + OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, + errors::InvalidArgument("Invalid/missing argument expression")); + ctx->SetOutputExpression(0, arg); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 9fa57b76f8e3649c03fe41f39638b88cb065ed0e..c022284fec6bc91951170e243ea3609c8d5d0c43 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -94,14 +94,10 @@ class BCastGradArgsOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), errors::InvalidArgument("In[", i, "] must be a vector.", in_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal)); - - BCast::Vec vec; - for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(literal.Get({i})); - } - shapes.push_back(vec); + std::vector vec; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec)); + + shapes.push_back(BCast::Vec(vec.begin(), vec.end())); } BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, 
bcast.IsValid(), diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e28755dd73bf6f8e1518dda2494cade79b7db22e..cd7c7f4a82df7a65829787efcb1fd2f77870e945 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -45,15 +46,13 @@ class ConcatBaseOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_tensor_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); - // TODO(annarev): add a helper to support int64 input. - const int32 concat_dim = literal.Get({}); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_tensor_shape.DebugString())); + int64 concat_dim; + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim)); std::vector values; std::vector shapes; @@ -63,9 +62,7 @@ class ConcatBaseOp : public XlaOpKernel { const TensorShape& input_shape = shapes[0]; int32 axis = concat_dim < 0 ? 
concat_dim + input_dims : concat_dim; - OP_REQUIRES(ctx, - (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), + OP_REQUIRES(ctx, 0 <= axis && axis < input_dims, errors::InvalidArgument( "ConcatOp : Expected concatenating dimensions in the range " "[", @@ -75,14 +72,11 @@ class ConcatBaseOp : public XlaOpKernel { // elements. std::vector input_data; int output_concat_dim = 0; - const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; - const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( - ctx, - in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), + ctx, in_shape.dims() == input_dims, errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", input_shape.DebugString(), " vs. shape[", i, @@ -131,11 +125,10 @@ class ConcatOffsetOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_shape = ctx->InputShape(0); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_shape.DebugString())); for (int i = 1; i < ctx->num_inputs(); ++i) { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), errors::InvalidArgument("input ", i, @@ -162,39 +155,38 @@ class ConcatOffsetOp : public XlaOpKernel { // [0, 5, 0, 0] const int32 N = ctx->num_inputs() - 1; const TensorShape inp0_shape = ctx->InputShape(1); - xla::Literal inp0_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal)); - const int64 dims = inp0_shape.num_elements(); + std::vector inp0_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims)); + const int64 
inp0_rank = inp0_shape.num_elements(); - xla::Literal concat_dim_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = concat_dim_literal.Get({}); + int64 cdim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim)); - VLOG(1) << "ConcatOffset " << cdim << "," << dims; - int32 axis = cdim < 0 ? cdim + dims : cdim; - OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), + VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank; + int32 axis = cdim < 0 ? cdim + inp0_rank : cdim; + OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank), errors::InvalidArgument("Concat dim is out of range: ", axis, - " vs. ", dims)); + " vs. ", inp0_rank)); int32 offset = 0; for (int i = 0; i < N; ++i) { const TensorShape inp_shape = ctx->InputShape(1 + i); - OP_REQUIRES(ctx, dims == inp_shape.num_elements(), - errors::InvalidArgument("input ", i, " should contain ", dims, - " elements, but got ", + OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(), + errors::InvalidArgument("input ", i, " should contain ", + inp0_rank, " elements, but got ", inp_shape.num_elements())); - xla::Literal inp_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal)); + std::vector inp_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims)); - Tensor out_constant(DT_INT32, TensorShape({dims})); + Tensor out_constant(DT_INT32, TensorShape({inp0_rank})); auto out_vec = out_constant.vec(); - for (int64 j = 0; j < dims; ++j) { + for (int64 j = 0; j < inp0_rank; ++j) { if (j == axis) { out_vec(j) = offset; - offset += inp_literal.Get({j}); + offset += inp_dims[j]; } else { - const int32 inp0_element = inp0_literal.Get({j}); - const int32 inp_element = inp_literal.Get({j}); - OP_REQUIRES(ctx, (inp0_element == inp_element), + const int32 inp0_element = inp0_dims[j]; + const int32 inp_element = inp_dims[j]; + OP_REQUIRES(ctx, inp0_element == inp_element, errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", inp0_element, " vs. 
", inp_element)); diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 2628ef8e2454976aeff3859fa5dc1d8e106f32e1..dff8af800229b9605bb93e0498bc5e5cf012f244 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); - if (proto_.dtype() == DT_STRING) { - LOG(WARNING) << "Not computing Const of type DT_STRING"; - ctx->SetInvalidOutput(0); - return; - } xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e9bdb15aa0c57fd95530798f87c68e2e63e84e1d..35e0625dbb0d4c696d36cce642d6f50f1d220c45 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -33,39 +34,20 @@ class FillOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // The output of this Op is a tensor of shape 'dims_shape' with each // element set to the scalar 'dims_literal'. 
- const TensorShape dims_shape = ctx->InputShape(0); - const TensorShape value_shape = ctx->InputShape(1); + const TensorShape dims_shape = ctx->InputShape("dims"); + const TensorShape value_shape = ctx->InputShape("value"); OP_REQUIRES( - ctx, IsLegacyVector(dims_shape), + ctx, TensorShapeUtils::IsVector(dims_shape), errors::InvalidArgument("dims must be a vector of int32, got shape ", dims_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(value_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(value_shape), errors::InvalidArgument("value must be a scalar, got shape ", value_shape.DebugString())); - // Evaluate the 'dims' constant input, reshaping to a vector if it - // was a 'legacy' vector (secretly a scalar). - xla::Literal dims_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped( - 0, {dims_shape.num_elements()}, &dims_literal)); - // Convert the dims literal into a vector that we can pass to - // XlaBuilder. - std::vector broadcast; - broadcast.reserve(dims_literal.shape().dimensions(0)); - for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(dims_literal.Get({i})); - } - // Look up the value input, reshaping to a scalar if it was a - // 'legacy' scalar (secretly a vector). - xla::XlaOp data = ctx->Input(1); - if (value_shape.dims() > 0) { - CHECK_EQ(value_shape.dims(), 1); - data = xla::Reshape(data, {}); - } - // Emit the actual computation, which broadcasts the scalar to the - // desired shape. 
- auto result = xla::Broadcast(data, broadcast); + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims)); + auto result = xla::Broadcast(ctx->Input("value"), dims); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index d069373086a6dcd6e7901abfc63d851a731da321..e310db2162da0997204f85bc3ca42e7b0460e1e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -48,9 +48,8 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // We require that the dimension argument is a constant, since it lets us // dispatch to a specialized custom-call function without any run-time // overhead, when compiling ahead-of-time. - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = literal.Get({}); + int64 dim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES( ctx, dim < input_shape.dims(), diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 4833a9662dd12ca72b5715373b549af105625d45..f6b8534f4d7c537e5b708ee000e00cb92123584b 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -41,10 +41,8 @@ class MirrorPadOp : public XlaOpKernel { for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); - TF_ASSIGN_OR_RETURN(int64 lhs_padding, - pad_literal.GetIntegralAsS64({dimno, 0})); - TF_ASSIGN_OR_RETURN(int64 rhs_padding, - pad_literal.GetIntegralAsS64({dimno, 1})); + int64 lhs_padding = pad_literal.Get({dimno, 0}); + int64 rhs_padding = pad_literal.Get({dimno, 1}); int64 dim_size = original_shape.dimensions(dimno); // Padding amounts on each side must 
be no more than the size of the @@ -65,8 +63,8 @@ class MirrorPadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); @@ -81,23 +79,19 @@ class MirrorPadOp : public XlaOpKernel { TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); // Evaluate the 'padding' constant input, reshaping to a matrix. xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::XlaBuilder* b = ctx->builder(); - auto in0 = ctx->Input(0); + auto in0 = ctx->Input("input"); xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 3f5445b4821b0918f3b220ecfe2be20bccb33dc2..36ea70ac392ff18fb52d400efa886533f8335eba 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -29,40 +30,36 @@ class PadOp : public XlaOpKernel { explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); const int dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); - if (fixed_dims == 0) { + xla::XlaOp input = ctx->Input("input"); + if (dims == 0) { // Tensor is rank 0. Return it unchanged. - ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(0, input); return; } - // Evaluate the 'padding' constant input, reshaping to a matrix. 
xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::PaddingConfig config; - for (int i = 0; i < fixed_dims; ++i) { + for (int i = 0; i < dims; ++i) { auto* dim = config.add_dimensions(); - int before = pad_literal.Get({i, 0}); - int after = pad_literal.Get({i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument( "Paddings must be non-negative: ", before, " ", after)); @@ -73,12 +70,13 @@ class PadOp : public XlaOpKernel { // PadV2 added a "constant_values" input that indicates the pad value. xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), - errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")), + errors::InvalidArgument("constant_values must be a scalar.")); + ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(input, zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 47a4eac20669c225a653bd3f1f00eeafd0845a42..fa1b6b91710f5507f41f3f69b0715398ae879aaf 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -36,7 +37,7 @@ class ReshapeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); const TensorShape sizes_shape = ctx->InputShape(1); // Preliminary validation of sizes. - OP_REQUIRES(ctx, IsLegacyVector(sizes_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape), errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e172c649325adb6f7761ce0be141f21e8d545bc1..6970dd0a00641c9f88571561501fb3454fb3eab3 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -46,61 +47,8 @@ class RetvalOp : public XlaOpKernel { // compilation. 
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::XlaOp input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - DataType input_type = ctx->input_type(0); - XlaContext& tc = XlaContext::Get(ctx); - - if (input_type == DT_RESOURCE) { - XlaResource* resource; - OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - ctx->SetStatus(tc.AddResourceRetval(index_, resource)); - return; - } - - auto is_constant = ctx->builder()->IsConstant(input); - if (!is_constant.ok()) { - ctx->SetStatus(is_constant.status()); - return; - } - - if (tc.resolve_compile_time_constants() && - (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); - OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); - } else { - TensorShape shape = ctx->InputShape(0); - ctx->SetStatus(is_constant.status()); - TensorShape representation_shape; - if (tc.is_entry_computation()) { - xla::StatusOr shape_or_status = - tc.RepresentationShape(shape, ctx->input_type(0)); - if (!shape_or_status.ok()) { - ctx->SetStatus(shape_or_status.status()); - return; - } else { - representation_shape = shape_or_status.ValueOrDie(); - } - } else { - representation_shape = shape; - } - - xla::XlaOp output = input; - if (tc.is_entry_computation()) { - output = xla::Reshape(input, representation_shape.dim_sizes()); - } else { - // The core from which a return value is returned depends on the - // device assignment of the input to the retval. Since we can't change - // the device assignment of "input" at this point, we must always - // introduce an operator here, even if the shape does not change. - // TODO(b/76097077): propagate device assignments onto arguments and - // return values of functions, and then reshape unconditionally. 
- output = - xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); - } - tc.AddRetval(index_, dtype_, shape, output); - } + XlaContext& xla_context = XlaContext::Get(ctx); + xla_context.SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 56b80cb4a299c07157a166208b96ad369075aa83..2ceadaf79c5cef35ad50aa84a0d66a46527a6458 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel { } // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); - std::vector revdims(x_shape.dims()); - std::copy(lax.data().begin(), lax.data().end(), - revdims.begin()); - std::vector dimensions; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax)); + std::vector dimensions; for (int d = 0; d < x_shape.dims(); ++d) { - if (revdims[d]) { + if (lax.Get({d})) { dimensions.push_back(d); } } diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 379f4aeb0fc7bbfff59696726f5af231b1294c49..60b011ba6d9b64a89e4228ba2a213f72b67a462d 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -30,31 +30,6 @@ limitations under the License. 
namespace tensorflow { namespace { -template -Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = literal.Get({}); - return Status::OK(); -} - -Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - switch (literal.shape().element_type()) { - case xla::S32: - *value = literal.Get({}); - break; - case xla::S64: - *value = literal.Get({}); - break; - default: - return errors::InvalidArgument("Invalid argument type for argument", - index); - } - return Status::OK(); -} - // The type-specific part of the implementation of Range. template xla::StatusOr CreateRangeTensor( @@ -98,13 +73,13 @@ class RangeOp : public XlaOpKernel { const TensorShape start_in_shape = ctx->InputShape(0); const TensorShape limit_in_shape = ctx->InputShape(1); const TensorShape delta_in_shape = ctx->InputShape(2); - OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape), errors::InvalidArgument("limit must be a scalar, not shape ", limit_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape), errors::InvalidArgument("delta must be a scalar, not shape ", delta_in_shape.DebugString())); xla::Literal start, limit, delta; @@ -147,9 +122,9 @@ class LinSpaceOp : public XlaOpKernel { explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape start_in_shape = ctx->InputShape(0); - const TensorShape stop_in_shape = ctx->InputShape(1); - const TensorShape num_in_shape = 
ctx->InputShape(2); + const TensorShape start_in_shape = ctx->InputShape("start"); + const TensorShape stop_in_shape = ctx->InputShape("stop"); + const TensorShape num_in_shape = ctx->InputShape("num"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); @@ -163,16 +138,20 @@ class LinSpaceOp : public XlaOpKernel { DataType type = ctx->input_type(0); int64 num; - OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num)); OP_REQUIRES(ctx, num > 0, errors::InvalidArgument("Requires num > 0: ", num)); Tensor out_constant(type, TensorShape({num})); + xla::Literal start_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal)); + xla::Literal stop_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal)); + switch (type) { case DT_FLOAT: { - float start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + float start = start_literal.GetFirstElement(); + float stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; @@ -185,9 +164,8 @@ class LinSpaceOp : public XlaOpKernel { break; } case DT_DOUBLE: { - double start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + double start = start_literal.GetFirstElement(); + double stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 37b026aeb058b464acd74264766f187b787914aa..12830816ec16c9797f0fe4d8f3f13f5a8176161d 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { @@ -108,21 +109,16 @@ class ExpandDimsOp : public XlaOpKernel { explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape dim_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape dim_shape = ctx->InputShape("dim"); - // TODO(phawkins): the standard implementation of ExpandDimsOp seems to - // accept legacy scalars, even when they should be forbidden by the graphdef - // version. - OP_REQUIRES(ctx, dim_shape.num_elements() == 1, + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims)); + OP_REQUIRES(ctx, dims.size() == 1, errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); - - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); - - int dim = literal.data()[0]; + int dim = dims[0]; OP_REQUIRES(ctx, (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), @@ -148,7 +144,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 34980ead81815c2818a259e096148fcce9c9a3b1..88da64e5a217a0c026106f03cb26958f6738446c 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" @@ -42,8 +43,8 @@ class SliceOp : public XlaOpKernel { OP_REQUIRES( ctx, - IsLegacyVector(begin_tensor_shape) && - IsLegacyVector(size_tensor_shape) && + TensorShapeUtils::IsVector(begin_tensor_shape) && + TensorShapeUtils::IsVector(size_tensor_shape) && begin_tensor_shape.num_elements() == input_shape.dims() && size_tensor_shape.num_elements() == input_shape.dims(), errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 230a343f7966f19cda44991a747287ba675fca4c..7a0e240400b344ab25743997ce3baad81bd5f476 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -35,26 +35,16 @@ class SplitOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); - const TensorShape index_shape = ctx->InputShape(0); + const TensorShape split_dim_shape = ctx->InputShape("split_dim"); const TensorShape input_shape = ctx->InputShape(1); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); - - int32 split_dim_orig; - if (index_shape.dims() == 0) { - split_dim_orig = literal_index.Get({}); - } else { - OP_REQUIRES( - ctx, index_shape.dims() == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - "scalar or a vector with 1 element")); - OP_REQUIRES( - ctx, index_shape.dim_size(0) == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - 
"scalar or a vector with 1 element")); - split_dim_orig = literal_index.Get({0}); - } + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(split_dim_shape), + errors::InvalidArgument("split_dim must be a scalar but has rank ", + split_dim_shape.dims())); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig)); + int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() : split_dim_orig; OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), @@ -138,7 +128,6 @@ class SplitVOp : public XlaOpKernel { // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; - std::vector split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); OP_REQUIRES(ctx, split_size_shape.dims() == 1 && @@ -150,12 +139,11 @@ class SplitVOp : public XlaOpKernel { split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); // Get the dimension of this split. - xla::Literal split_size_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); + std::vector split_sizes; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes)); for (int i = 0; i < num_split; ++i) { - int slice_size; - slice_size = split_size_literal.Get({i}); + int64 slice_size = split_sizes[i]; if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, @@ -164,7 +152,6 @@ class SplitVOp : public XlaOpKernel { i)); neg_one_dim = i; } else { - split_sizes_vec[i] = slice_size; total_split_size += slice_size; } } @@ -183,7 +170,7 @@ class SplitVOp : public XlaOpKernel { total_split_size)); if (neg_one_dim >= 0) { - split_sizes_vec[neg_one_dim] = + split_sizes[neg_one_dim] = input_shape.dim_size(split_dim) - total_split_size; } @@ -195,7 +182,7 @@ class SplitVOp : public XlaOpKernel { std::vector strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); - int slice_size = split_sizes_vec[i]; + int slice_size = 
split_sizes[i]; output_shape.set_dim(split_dim, slice_size); // Slice out the ith split from the split dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d79cdad9fa2dabe1e236741955499b845064148f..7b96b43ad834c28aa0283c5ef4ac516618ca5134 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp); +REGISTER_XLA_OP( + Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(), + StackOp); class StackPushOp : public XlaOpKernel { public: @@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); }; -REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); +REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp); class StackPopOp : public XlaOpKernel { public: @@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); }; -REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); +REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp); class StackCloseOp : public XlaOpKernel { public: @@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); }; -REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); +REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 7b2cd5a5b08d80284d172c2ed5d6be4c355e76e0..e1c764f3d5c28cf0d812519e4a16786e1f2d3a3a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" @@ -44,7 +45,7 @@ class TileOp : public XlaOpKernel { const TensorShape multiples_shape = ctx->InputShape("multiples"); OP_REQUIRES( - ctx, IsLegacyVector(multiples_shape), + ctx, TensorShapeUtils::IsVector(multiples_shape), errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", multiples_shape.DebugString())); OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(), diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 48a211942d7c4405bf68189e641ee184db36b0ba..c9b324a243e4cc3ec64daa3ca0d285336a0d0154 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel { : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape perm_tensor_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("x"); + const TensorShape perm_tensor_shape = ctx->InputShape("perm"); // Preliminary validation of sizes. OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), @@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel { ". 
But input(1) is a vector of size ", perm_tensor_shape.num_elements())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); - - std::vector perm(dims); - std::copy(literal.data().begin(), literal.data().end(), - perm.begin()); + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm)); std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { - const int32 d = perm[i]; + const int64 d = perm[i]; OP_REQUIRES( ctx, 0 <= d && d < dims, errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); @@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel { xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - transposed = ctx->Input(0); + transposed = ctx->Input("x"); } else { - transposed = xla::Transpose(ctx->Input(0), transposed_order); + transposed = xla::Transpose(ctx->Input("x"), transposed_order); } // Conjugate the transposed result if this is ConjugateTransposeOp. diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0bdfc05726105e2d18362a691cbe2aab00bf77f3..a0ea6422d732b00fc1b8cf855d9c9ad603b87c82 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -80,24 +80,8 @@ XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); XLAJIT_MAKE_UNARY(Neg, -x); -// Implements Banker's rounding: numbers that are equidistant between two -// integers are rounded towards even. 
-xla::XlaOp RoundToEven(xla::XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); - - auto round_val = xla::Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); -} - -XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x)); XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 20103ec3ae00b57723e05326dbbb1b0f6e1a671a..67d08290033361f16dfff42b06af9b253e84963a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor) { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal)); + return literal.Clone(); +} + Status HostTensorToMutableBorrowingLiteral( Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { xla::Shape xla_shape; diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 1db7470ee2a839099454b772d4833492e033bc92..a153dddee6127ff9c0858220f2d8a735ab3f0e19 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,11 @@ namespace tensorflow { // 'host_tensor'. 
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); + +// Returns a Literal with the contents of 'host_tensor', backed by its own +// storage (i.e., not reusing 'host_tensor's buffers.) +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor); + // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer // owned by 'host_tensor', but is mutable via the xla::Literal methods. Status HostTensorToMutableBorrowingLiteral( diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 8b559c87506a6b519e2ad1d1bf22ab30c0ff161d..c9f486edc8d30954619db0967c988fe8e26938de 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package( default_visibility = [ "//learning/deepmind/public/wavenet/python:__subpackages__", + "//learning/deepmind/research/alphastar:__subpackages__", "//learning/tfx:__subpackages__", "//tensorflow:internal", ], diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index cb7843850c352eee2e55baf52a0c4445dc861d7b..ddb284966eeb97cc7c9d3ed77fb313e567975e59 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto( "XLACompilationDevice::MakeTensorFromProto should not be called"); } -XlaExpression::XlaExpression() = default; - -void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } - -void XlaExpression::set_constant_value(Tensor value) { - has_constant_value_ = true; - constant_value_ = std::move(value); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index a6e78825334fec748be5fee80669649df699d2fb..de6a3356e05d8ab45c269d7c6c653853d2c63a79 100644 
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,9 +18,6 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" @@ -38,8 +35,8 @@ class XlaCompilationAllocator; // This is a 'dummy' TensorFlow device that is only used to execute a // subgraph of XLA compilation Ops to construct a compiled version // of the subgraph's computation. It has a 'dummy' allocator that -// backs each Tensor with metadata indicating the computation the -// Tensor represents. +// backs each Tensor with an XlaExpression. The shape of the Tensor +// matches the shape of XlaExpression. // // We deliberately don't register a device factory because we *never* // want placement to put Ops on a compilation device. The device is created @@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// A XlaExpression wraps an XLA computation. Each Tensor on an -// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the XlaOp. Each -// expression is either a constant, or a function of previously-compiled -// expressions. -class XlaExpression { - public: - XlaExpression(); - - // handle() stores the XLA handle of the computation that the - // expression represents. 
- void set_handle(const xla::XlaOp& h); - const xla::XlaOp& handle() const { return handle_; } - - void set_constant_value(Tensor value); - bool has_constant_value() const { return has_constant_value_; } - const Tensor& constant_value() const { return constant_value_; } - - void set_resource(XlaResource* resource) { resource_ = resource; } - XlaResource* resource() const { return resource_; } - - private: - // The XLA handle of the expression's computation. - xla::XlaOp handle_; - - // If this expression is a constant with a known value, 'constant_value' is a - // host-memory Tensor containing the value. Used to avoid invoking XLA for - // expressions that are trivially constant. - bool has_constant_value_ = false; - Tensor constant_value_; - - XlaResource* resource_ = nullptr; // Not owned. -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index e177a5f07f5607a0f9de75e6a999ee492cd9db4f..a08d030ce710bdb97910c01a64f80199fc10d649 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -36,10 +36,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -48,7 +51,7 @@ namespace { // Checks that arguments `args` match types `types`. 
Status CheckSignature(const DataTypeVector& types, - const std::vector& args) { + absl::Span args) { if (args.size() != types.size()) { return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); @@ -63,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types, return Status::OK(); } +// Uses the _Arg and _Retval nodes in the graph to determine a core assignment +// for each argument and return value. +xla::StatusOr, std::map>> +ComputeArgAndRetvalCores(const Graph& graph) { + auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + std::map arg_cores; + std::map retval_cores; + for (const Node* n : graph.nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Arg index"; + arg_cores[index] = core; + } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Retval index"; + TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n)); + retval_cores[index] = core; + } + } + return std::make_pair(std::move(arg_cores), std::move(retval_cores)); +} + +Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, FunctionLibraryRuntime* flib, + int64 step_id) { + // Resource cleanup is a bit messy. 
XlaContext is a ref-counted resource; the + // resource manager takes ownership via Create, and unrefs via Cleanup. We + // explicitly add a reference to ensure the refcount at entry is maintained at + // all exit points; Create and Cleanup are always called in this function. + // + // The Executor requires us to use ScopedStepContainer. We wrap it in a + // unique_ptr so we can capture the cleanup status in the end. + xla_context->Ref(); + Status status; + auto step_container = absl::make_unique( + step_id, [&status, device](const string& name) { + status = device->resource_manager()->Cleanup(name); + }); + TF_RETURN_IF_ERROR(device->resource_manager()->Create( + step_container->name(), XlaContext::kXlaContextResourceName, + xla_context)); + + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); + // Explicitly clean up the step container, to capture the cleanup status. + step_container.reset(); + return Status::OK(); +} + +// Builds the XLA computation. +// - `args` is the list of input arguments +// - `retvals` is the list of retvals produced by _Retval operators, in index +// order. +// - `arg_cores` and `retval_cores` are mappings from arg/return indices to core +// assignments. +// - If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// - Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a ResourceUpdate, whose `input_index` is the index of a +// resource variable argument to the computation to be updated, and `type` is +// the type of the final output. 
+Status BuildComputation( + const std::vector& args, + const std::vector& retvals, + const std::map& arg_cores, const std::map& retval_cores, + const std::vector>& resources, + std::unique_ptr token_output, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, + std::vector* resource_updates) { + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); + + // Builds a no-op XLA computation. We need to set the sharding of outputs, but + // cannot change the sharding of the existing output op. To do this, we build + // a new identity op to which shardings can be applied. 
+ auto identity_op = [builder](xla::XlaOp op) { + return xla::GetTupleElement(xla::Tuple(builder, {op}), 0); + }; + + std::vector elems; + elems.reserve(retvals.size()); + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + const XlaExpression& retval = retvals[i]; + output.type = retval.dtype(); + switch (retval.kind()) { + case XlaExpression::Kind::kConstant: + output.is_constant = true; + output.constant_value = retval.constant_value(); + output.shape = output.constant_value.shape(); + break; + + case XlaExpression::Kind::kXlaOp: { + output.is_constant = false; + TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); + xla::XlaOp value = retval.handle(); + auto it = retval_cores.find(i); + xla::XlaScopedShardingAssignment assign_sharding( + builder, it == retval_cores.end() + ? absl::optional() + : xla::sharding_builder::AssignDevice(it->second)); + if (shape_representation_fn) { + // If there is a shape representation function, reshape the output + // tensor to the shape given by the representation shape function. + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( + output.shape, output.type)); + value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + } else if (it != retval_cores.end()) { + // Apply the sharding to the output, if there is a core assignment. + value = identity_op(value); + } + elems.push_back(value); + break; + } + + case XlaExpression::Kind::kResource: + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); + output.shape = retval.resource()->shape(); + break; + + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument( + "Invalid expression returned by computation. " + "This probably means a return value was not set."); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. 
+ std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + DCHECK_LT(resource->arg_num(), args.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + auto it = arg_cores.find(resource->arg_num()); + const int core = it == arg_cores.end() ? -1 : it->second; + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::XlaScopedShardingAssignment assign_sharding( + builder, core == -1 ? absl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Ensures the correct sharding is applied to the output. + handle = identity_op(handle); + + elems.push_back(handle); + } + } + + // If we have token output, append it as the last one. 
+ if (token_output) { + elems.push_back(*token_output); + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); + } + + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + return Status::OK(); +} + } // namespace bool XlaCompiler::Argument::operator==( @@ -83,6 +320,39 @@ bool XlaCompiler::Argument::operator==( return constant_value.tensor_data() == other.constant_value.tensor_data(); } +string XlaCompiler::Argument::HumanString() const { + string common; + if (!name.empty()) { + common = absl::StrCat(" name=", name); + } + absl::StrAppend(&common, " type=", DataTypeString(type), + " shape=", shape.DebugString()); + switch (kind) { + case kInvalid: + return "invalid"; + case kConstant: + return absl::StrCat("kind=constant", common, + " value=", constant_value.DebugString()); + case kResource: { + string output = absl::StrCat("kind=resource", common, " resource_kind=", + XlaResource::KindToString(resource_kind), + " initialized=", initialized); + if (tensor_array_size >= 0) { + absl::StrAppend(&output, " tensor_array_size=", tensor_array_size); + } + if (!tensor_array_gradients.empty()) { + absl::StrAppend(&output, " tensor_array_gradients=", + absl::StrJoin(tensor_array_gradients, ",")); + } + return output; + } + case kParameter: + return absl::StrCat("kind=parameter", common); + case kToken: + return absl::StrCat("token", common); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), @@ -110,8 +380,13 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options 
options) // The default shape representation function is the identity. if (!options_.shape_representation_fn) { - options_.shape_representation_fn = [](const TensorShape& shape, - DataType type) { return shape; }; + options_.shape_representation_fn = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -171,15 +446,16 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { return graph; } -Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - std::vector args, - XlaCompiler::CompilationResult* result) { +Status XlaCompiler::CompileFunction( + const XlaCompiler::CompileOptions& options, const NameAttrList& function, + absl::Span args, + XlaCompiler::CompilationResult* result) { const string function_id = Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; - auto it = cache_.find({function_id, args}); + const std::vector arg_vector(args.begin(), args.end()); + auto it = cache_.find({function_id, arg_vector}); if (it != cache_.end()) { *result = it->second; return Status::OK(); @@ -212,14 +488,16 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kArgOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). 
for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kRetOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -235,7 +513,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; - cache_[{function_id, args}] = *result; + cache_[{function_id, arg_vector}] = *result; return Status::OK(); } @@ -247,25 +525,24 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape; if (is_entry_computation) { TF_ASSIGN_OR_RETURN( - shape, options_.shape_representation_fn(arg.shape, arg.type)); + *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); } else { - shape = arg.shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); } - return TensorShapeToXLAShape(arg.type, shape, xla_shape); + return Status::OK(); } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, - options_.shape_representation_fn(arg.shape, arg.type)); - return TensorShapeToXLAShape(arg.type, representation_shape, - xla_shape); + TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( + arg.shape, arg.type)); + + return Status::OK(); } case XlaResource::kTensorArray: { if (arg.tensor_array_size < 0) { @@ -314,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } } -namespace { - -Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, - XlaCompilationDevice* device, FunctionLibraryRuntime* flib, - int64 step_id) { - // 
Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the - // resource manager takes ownership via Create, and unrefs via Cleanup. We - // explicitly add a reference to ensure the refcount at entry is maintained at - // all exit points; Create and Cleanup are always called in this function. - // - // The Executor requires us to use ScopedStepContainer. We wrap it in a - // unique_ptr so we can capture the cleanup status in the end. - xla_context->Ref(); - Status status; - auto step_container = absl::make_unique( - step_id, [&status, device](const string& name) { - status = device->resource_manager()->Cleanup(name); - }); - TF_RETURN_IF_ERROR(device->resource_manager()->Create( - step_container->name(), XlaContext::kXlaContextResourceName, - xla_context)); - - GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); - TF_RETURN_IF_ERROR(graph_compiler.Compile()); - // Explicitly clean up the step container, to capture the cleanup status. - step_container.reset(); - return Status::OK(); -} - -// Builds the XLA computation. -// `args` is the list of input arguments, `retvals` is the list of retvals -// produced by _Retval operators, in index order. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. 
-Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - std::unique_ptr token_output, - bool return_updated_values_for_all_resources, bool always_return_tuple, - xla::XlaBuilder* builder, xla::XlaComputation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (int i = 0; i < retvals.size(); ++i) { - XlaCompiler::OutputDescription& output = (*outputs)[i]; - output.type = retvals[i].type; - output.shape = retvals[i].shape; - const XlaExpression& retval = retvals[i].expression; - if (retval.has_constant_value()) { - output.is_constant = true; - output.constant_value = retval.constant_value(); - } else if (retval.resource() != nullptr) { - output.is_constant = false; - output.input_index = retval.resource()->arg_num(); - } else { - output.is_constant = false; - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - // Attach a common operator name as metadata. This has no semantic effect — it - // merely makes the HLO graph more readable when visualized via TensorBoard, - // since TensorBoard forms groups out of operators with similar names. 
- xla::OpMetadata retval_metadata; - retval_metadata.set_op_name("XLA_Retvals"); - builder->SetOpMetadata(retval_metadata); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = - modified || - !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? absl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); - } - } - - // If we have token output, append it as the last one. 
- if (token_output) { - elems.push_back(*token_output); - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. We *always* form a tuple here to ensure that - // the output value is the last thing added into the XLA computation, even - // if there is only one output value. - auto tuple = xla::Tuple(builder, elems); - if (!always_return_tuple && elems.size() == 1) { - xla::GetTupleElement(tuple, 0); - } - builder->ClearOpMetadata(); - - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - std::vector* arg_cores, std::vector* arg_expressions, + const std::map& arg_cores, + std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); - *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. 
@@ -504,7 +622,7 @@ Status XlaCompiler::BuildArguments( arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); - arg_expression.set_resource(resource); + arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } @@ -516,7 +634,7 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kConstant: - arg_expression.set_constant_value(arg.constant_value); + arg_expression = XlaExpression::Constant(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -541,26 +659,6 @@ Status XlaCompiler::BuildArguments( *input_shapes = arg_shapes; } - // Use the _Arg nodes in the graph to resolve core assignments. - for (const Node* n : graph.nodes()) { - if (absl::string_view(n->type_string()) != "_Arg") continue; - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - TF_RET_CHECK(index >= 0 && index < args.size()) - << "_Arg out of bounds: " << index << " vs " << args.size(); - TF_ASSIGN_OR_RETURN( - auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == - xla::OpSharding::Type::OpSharding_Type_MAXIMAL); - const int core = sharding.value().tile_assignment_devices(0); - if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { - (*arg_cores)[index] = core; - } - } - } - // Attach a common operator name as metadata. This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. 
@@ -576,11 +674,10 @@ Status XlaCompiler::BuildArguments( xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); for (int64 parameter : *input_mapping) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; + auto it = arg_cores.find(parameter); + const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); + xla::sharding_builder::AssignDevice(core); } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); @@ -589,7 +686,8 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -597,7 +695,8 @@ Status XlaCompiler::BuildArguments( } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -632,14 +731,14 @@ Status XlaCompiler::BuildArguments( // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. 
if (is_entry_computation) { - arg_expression.set_handle( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); + arg_expression = XlaExpression::XlaOp( + xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); } else { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } break; case XlaCompiler::Argument::kToken: { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); break; } case XlaCompiler::Argument::kConstant: @@ -653,46 +752,48 @@ Status XlaCompiler::BuildArguments( } Status XlaCompiler::CompileSingleOp( - const XlaCompiler::CompileOptions& options, string const& name, - OpKernelContext* ctx, const std::vector& args, - CompilationResult* result) { + const XlaCompiler::CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result) { // TODO(b/74182462): We implement this by creating a new dummy Graph including // _Arg nodes, and let CompileGraph walk it. This could be optimized. std::unique_ptr graph(new Graph(OpRegistry::Global())); Status status; // First create the actual node we care about computing. - Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + Node* main_node = graph->AddNode(node_def, &status); TF_RETURN_IF_ERROR(status); // Create dummy _Arg nodes. Link these to `node` and also via a control // dependency edge to the _SOURCE node. 
- for (int64 i = 0; i < ctx->num_inputs(); ++i) { + for (int64 i = 0; i < args.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); - Status status = NodeBuilder(name, "_Arg") - .ControlInput(graph->source_node()) - .Attr("T", ctx->input_dtype(i)) - .Attr("index", i) - .Finalize(graph.get(), &node); + string arg_name = absl::StrCat("_arg", i); + Status status = + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) + .ControlInput(graph->source_node()) + .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE + : args[i].type) + .Attr("index", i) + .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); graph->AddEdge(node, 0, main_node, i); } // Similarly with return values, create dummy _Retval nodes fed by `node`. - for (int64 i = 0; i < ctx->num_outputs(); ++i) { + for (int64 i = 0; i < result_types.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); - Status status = NodeBuilder(name, "_Retval") + string retval_name = absl::StrCat("_retval", i); + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) .Input(main_node, i) - .Attr("T", ctx->expected_output_dtype(i)) + .Attr("T", result_types[i]) .Attr("index", i) .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } FixupSourceAndSinkEdges(graph.get()); - return CompileGraph(options, name, std::move(graph), args, result); + return CompileGraph(options, node_def.name(), std::move(graph), args, result); } namespace { @@ -747,12 +848,38 @@ Status ValidateGraph(const Graph* graph, return Status::OK(); } +// Converts the value of any expressions whose values are known at compile-time +// to constants. 
+Status ResolveConstantExpressionsToConstants( + xla::Client* client, absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kXlaOp) { + TF_ASSIGN_OR_RETURN(absl::optional constant, + expression.ResolveConstant(client)); + if (constant.has_value()) { + expression = XlaExpression::Constant(*constant); + } + } + } + return Status::OK(); +} + +void ConvertConstantsToExpressions(xla::XlaBuilder* builder, + absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kConstant) { + expression = + XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype()); + } + } +} + } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; @@ -774,13 +901,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext( - this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, options.is_entry_computation, - &options_.shape_representation_fn); + XlaContext* context = + new XlaContext(this, &builder, options_.allow_cpu_custom_calls, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); - std::vector real_args(args); + std::vector real_args(args.begin(), args.end()); int token_input_index = -1; std::unique_ptr token_output; if (options.add_token_input_output) { @@ -792,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, real_args.push_back(token_arg); } + std::map arg_cores; + std::map retval_cores; + TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), + ComputeArgAndRetvalCores(*graph)); + std::vector 
arg_expressions; - std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, &arg_expressions, &result->input_mapping, &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); @@ -843,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); + std::vector retvals = context->retvals(); + if (options.resolve_compile_time_constants) { + TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals))); + } else { + ConvertConstantsToExpressions(&builder, absl::Span(retvals)); + } TF_RETURN_IF_ERROR(BuildComputation( - real_args, arg_cores, context->retvals(), context->resources(), - std::move(token_output), options.return_updated_values_for_all_resources, + real_args, retvals, arg_cores, retval_cores, context->resources(), + std::move(token_output), + options.is_entry_computation ? options_.shape_representation_fn + : ShapeRepresentationFn{}, + options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, &result->resource_updates)); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 2cc603a58016a509fafdf6f95423dd6c0864cce3..63426124686e1b92a3534b7e365b8282008b8455 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,10 +18,13 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" @@ -118,7 +121,7 @@ class XlaCompiler { // The type of the argument. If the argument is a resource, this // is the type of the variable's value, not DT_RESOURCE. - DataType type; + DataType type = DT_INVALID; // The shape of the argument. For: // * a parameter: the shape of the parameter. @@ -155,6 +158,9 @@ class XlaCompiler { std::set tensor_array_gradients; bool operator==(const Argument& other) const; + + // Returns a human-readable summary of the argument. + string HumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -259,8 +265,7 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function(const TensorShape&, - DataType)> + typedef std::function(const TensorShape&, DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. @@ -316,22 +321,23 @@ class XlaCompiler { Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, - std::vector args, CompilationResult* result); + absl::Span args, + CompilationResult* result); // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. 
Status CompileGraph(const CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result); - // Compiles a single Op, given by an OpKernelContext, into an + // Compiles a single Op, given by `node_def`, into an // xla::XlaComputation. Similar to CompileFunction but takes a single Op as // input. - Status CompileSingleOp(const CompileOptions& options, string const& name, - OpKernelContext* ctx, - const std::vector& args, + Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result); // Returns the shape of the XLA parameter for an argument 'arg'. @@ -411,7 +417,8 @@ class XlaCompiler { Status BuildArguments(const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, - XlaContext* context, std::vector* arg_cores, + XlaContext* context, + const std::map& arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 4ef154f856b9284a6c97f2c3072b198ccfb5e517..aaee208f6349d56f685481977cea55c8dd5e7938 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -1018,9 +1019,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. 
XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); @@ -1086,9 +1089,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 20e1ee2ddb390edd3a7d881022c68072a69193dc..43095fbb47351617a0de12a088c947106ccaa641 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( + bool allow_cpu_custom_calls, + const std::function( const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants), - is_entry_computation_(is_entry_computation), shape_representation_fn_(shape_representation_fn) {} string 
XlaContext::DebugString() { return "TLA JIT context"; } -// This is called by the Retval Op to associate a computed value -// with a specific return value of the subgraph. -void XlaContext::AddRetval(int retval_index, DataType type, - const TensorShape& shape, const xla::XlaOp& handle) { - VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; - // Add the return value to the list being built up. - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); +void XlaContext::SetRetval(int index, const XlaExpression& expression) { + if (retvals_.size() <= index) { + retvals_.resize(index + 1); } - XlaExpression e; - e.set_handle(handle); - retvals_[retval_index] = Retval{type, shape, e}; + retvals_[index] = expression; } -Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal) { - VLOG(1) << "Adding retval index " << retval_index - << " with non-data-dependent tensor to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - XlaExpression e; - e.set_constant_value(value); - retvals_[retval_index] = Retval{dtype, value.shape(), e}; - return Status::OK(); -} - -Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { - VLOG(1) << "Adding retval index " << retval_index << " with resource " - << resource->name() << ":" << resource->shape().DebugString() - << " to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - XlaExpression e; - e.set_resource(resource); - retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; - return Status::OK(); -} - -xla::XlaBuilder* XlaContext::builder() { return builder_; } - Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, @@ 
-133,7 +93,7 @@ Status XlaContext::CreateResource( return Status::OK(); } -xla::StatusOr XlaContext::RepresentationShape( +xla::StatusOr XlaContext::RepresentationShape( const TensorShape& shape, DataType type) const { return (*shape_representation_fn_)(shape, type); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 4da891634e97dd67af0ef09ef33dbc7a4d19743b..dbfd344c9bad8a5d05abb6a3b902ed3baebbe02a 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,8 +20,8 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -46,9 +46,8 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( + bool allow_cpu_custom_calls, + const std::function( const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. @@ -57,37 +56,19 @@ class XlaContext : public ResourceBase { XlaCompiler* compiler() const { return compiler_; } // Returns the XlaBuilder that Ops use for compiling new expressions. 
- xla::XlaBuilder* builder(); + xla::XlaBuilder* builder() { return builder_; } bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool resolve_compile_time_constants() const { - return resolve_compile_time_constants_; - } - bool is_entry_computation() const { return is_entry_computation_; } - const std::vector& args() const { return args_; } void set_args(std::vector args); - struct Retval { - DataType type; - TensorShape shape; - // An XlaExpression representing the Retval's value. - XlaExpression expression; - }; - const std::vector& retvals() { return retvals_; } - - // This is called by the Retval Op to associate a computed value - // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, const TensorShape& shape, - const xla::XlaOp& handle); + const std::vector& retvals() { return retvals_; } - // As for Retval, but for return values that are compile-time constants. - Status AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal); - - // As for Retval, but for return values that are resource handles. - Status AddResourceRetval(int retval_index, XlaResource* resource); + // Sets a return value. + // Since we do not always know in advance how many return values there are, + // grows the return values vector to size index+1 if it is smaller. + void SetRetval(int index, const XlaExpression& expression); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` @@ -105,8 +86,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - xla::StatusOr RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. 
This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -140,31 +121,19 @@ class XlaContext : public ResourceBase { // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; - // If true, constant return values are returned as Tensors instead of - // run-time computation outputs. - const bool resolve_compile_time_constants_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // Is this a top-level computation, or an inner computation (e.g., a while - // body)? - const bool is_entry_computation_; - - // A function that describes how the shapes of - // a) argument and return value, for entry computations - // b) variables, for all computations, - // should be represented in XLA. Parameters/return values will be shaped - // according to this function, and reshaped back to/from their declared shapes - // for computations. Must be non-null. - const std::function(const TensorShape&, DataType)>* + // Describes the on-host shapes of parameters and return values. Also see: + // XlaDevice::Options::shape_representation_fn. + const std::function(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca0309166b7c73d1a5a818091e2a30fa112a4de4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_expression.h" + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +XlaExpression::XlaExpression() = default; + +XlaExpression XlaExpression::Invalid() { + XlaExpression e; + e.kind_ = Kind::kInvalid; + return e; +} + +XlaExpression XlaExpression::Constant(Tensor value) { + XlaExpression e; + e.kind_ = Kind::kConstant; + e.dtype_ = value.dtype(); + e.constant_value_ = value; + return e; +} + +XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { + XlaExpression e; + e.kind_ = Kind::kXlaOp; + e.dtype_ = dtype; + e.handle_ = value; + return e; +} + +XlaExpression XlaExpression::Resource(XlaResource* resource) { + XlaExpression e; + e.kind_ = Kind::kResource; + e.dtype_ = DT_RESOURCE; + e.resource_ = resource; + return e; +} + +string XlaExpression::HumanString() const { + switch (kind_) { + case Kind::kInvalid: + return "invalid"; + case Kind::kConstant: + return "constant"; + case Kind::kXlaOp: + return "xla_op"; + case Kind::kResource: + return "resource"; + } +} + +xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + switch (kind_) { + case 
Kind::kConstant: { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR( + HostTensorToBorrowingLiteral(constant_value_, &literal)); + return xla::ConstantLiteral(builder, literal); + } + case Kind::kXlaOp: + if (builder != handle_.builder()) { + return errors::InvalidArgument( + "Mismatched builders in XlaExpression::AsXlaOp"); + } + return handle_; + default: + return errors::InvalidArgument("AsXlaOp called on XlaExpression: ", + HumanString()); + } + }); +} + +xla::StatusOr> XlaExpression::ResolveConstant( + xla::Client* client) const { + switch (kind()) { + case Kind::kConstant: + return {constant_value()}; + case Kind::kXlaOp: + break; + case Kind::kResource: + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveConstant called on XlaExpression: ", HumanString()); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + handle().builder()->IsConstant(handle())); + if (!is_constant) return {absl::nullopt}; + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildConstantSubGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. 
+ std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor)); + return {tensor}; +} + +xla::StatusOr XlaExpression::GetShape() const { + switch (kind_) { + case Kind::kConstant: + return constant_value().shape(); + case Kind::kXlaOp: { + TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, + handle().builder()->GetShape(handle())); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + return shape; + } + case Kind::kResource: + return TensorShape({}); + case Kind::kInvalid: + return errors::InvalidArgument( + "GetShape() called on invalid XlaExpression"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h new file mode 100644 index 0000000000000000000000000000000000000000..bed6761d362a98d344003c1edea342e68c31ef07 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA +// compilation. +// An expression is one of: +// * a constant tensor. +// * an xla::XlaOp, representing a symbolic XLA value. +// * a resource, e.g., a variable, represented as an XlaResource pointer. +// +// Constant tensors are mostly an optimization to avoid passing large constants +// to XLA, but are also sometimes used to represent tensors that have no XLA +// representation, for example, DT_STRING tensors. A canonical use case might be +// an error message string. +class XlaExpression { + public: + enum class Kind { + kInvalid, + kConstant, + kXlaOp, + kResource, + }; + + XlaExpression(); + XlaExpression(const XlaExpression&) = default; + XlaExpression& operator=(const XlaExpression&) = default; + + // Builds an invalid expression. (Same as the default constructor, but makes + // the intent clearer.) + static XlaExpression Invalid(); + + // Builds a constant XLA expression. + static XlaExpression Constant(Tensor value); + + // Builds a XlaOp expression. Since the mapping from TF data types to XLA + // types is not 1-1, the TF type must also be provided; in general it cannot + // be derived from the XLA type. + static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + + // Builds a resource expression. 
+ static XlaExpression Resource(XlaResource* resource); + + Kind kind() const { return kind_; } + + DataType dtype() const { return dtype_; } + + // handle() returns the XlaOp that backs a kXlaOp expression. + const xla::XlaOp& handle() const { return handle_; } + + const Tensor& constant_value() const { return constant_value_; } + + XlaResource* resource() const { return resource_; } + + // Returns a human-readable summary of the expression. + string HumanString() const; + + // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns + // an erroneous XlaOp if the expression is not a constant or an expression. + xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; + + // If a kXlaOp or kConstant expression can be resolved to a compile-time + // constant, returns the value as a host-memory Tensor. Returns an empty + // optional if it cannot be resolved. Returns an error if passed a resource + // expression. + xla::StatusOr> ResolveConstant( + xla::Client* client) const; + + // Returns the shape of the tensor. + // The shape of a resource is the shape of a resource handle (i.e., a scalar), + // not the shape of the resource's value. + xla::StatusOr GetShape() const; + + private: + Kind kind_ = Kind::kInvalid; + + DataType dtype_ = DT_INVALID; + + // The XLA handle of the expression's computation, if kind_ == kXlaOp. + xla::XlaOp handle_; + + // The value of the constant, if kind_ == kConstant. + Tensor constant_value_; + + // The resource, if kind_ == kResource. Not owned. + XlaResource* resource_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..84202c931390f2d68f6d381aef0752bfff00a53d --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class XlaExpressionTest : public ::testing::Test { + protected: + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + builder_ = absl::make_unique("acomputation"); + constant_ = test::AsScalar(42); + op_ = xla::ConstantR0(builder_.get(), 7); + non_constant_op_ = xla::Parameter( + builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"); + resource_ = absl::make_unique( + XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"), + DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1, + /*tensor_array_gradients=*/std::set(), + 
/*tensor_array_multiple_writes_aggregate=*/false); + } + + xla::Client* client_; + std::unique_ptr builder_; + Tensor constant_; + xla::XlaOp op_; + xla::XlaOp non_constant_op_; + std::unique_ptr resource_; +}; + +TEST_F(XlaExpressionTest, Kind) { + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind()); + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind()); + EXPECT_TRUE(XlaExpression::Kind::kConstant == + XlaExpression::Constant(constant_).kind()); + EXPECT_TRUE(XlaExpression::Kind::kXlaOp == + XlaExpression::XlaOp(op_, DT_INT32).kind()); + EXPECT_TRUE(XlaExpression::Kind::kResource == + XlaExpression::Resource(resource_.get()).kind()); +} + +TEST_F(XlaExpressionTest, HumanString) { + EXPECT_EQ("invalid", XlaExpression().HumanString()); + EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString()); + EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString()); + EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString()); + EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString()); +} + +TEST_F(XlaExpressionTest, AsXlaOp) { + xla::XlaOp op_as_op = + XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get()); + EXPECT_TRUE(op_.IsIdenticalTo(op_as_op)); + + xla::XlaOp const_as_op = + XlaExpression::Constant(constant_).AsXlaOp(builder_.get()); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, + builder_->BuildConstantSubGraph(const_as_op)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal value, + client_->ComputeConstant(computation)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0(42), + value)); +} + +TEST_F(XlaExpressionTest, GetShape) { + EXPECT_FALSE(XlaExpression().GetShape().ok()); + EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok()); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape, + XlaExpression::Resource(resource_.get()).GetShape()); + EXPECT_EQ(TensorShape({}), resource_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape, + 
XlaExpression::XlaOp(op_, DT_INT32).GetShape()); + EXPECT_EQ(TensorShape({}), op_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape, + XlaExpression::Constant(constant_).GetShape()); + EXPECT_EQ(TensorShape({}), constant_shape); +} + +TEST_F(XlaExpressionTest, ResolveConstant) { + EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok()); + EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok()); + EXPECT_FALSE( + XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional op_constant, + XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_)); + ASSERT_TRUE(op_constant.has_value()); + test::ExpectTensorEqual(test::AsScalar(7), *op_constant); + + TF_ASSERT_OK_AND_ASSIGN(absl::optional op_nonconstant, + XlaExpression::XlaOp(non_constant_op_, DT_FLOAT) + .ResolveConstant(client_)); + EXPECT_FALSE(op_nonconstant.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional constant_constant, + XlaExpression::Constant(constant_).ResolveConstant(client_)); + ASSERT_TRUE(constant_constant.has_value()); + test::ExpectTensorEqual(constant_, *constant_constant); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index dd3498ef7aa242d3ad946cae5f60bc2c8853a342..8dd8def0549f2b39d4c9863bb535f19703c3ef22 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().valid() || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle(); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); return expression; } -// Retrieves an uninitialized XlaExpression from a newly-allocated tensor. -static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { +// Assigns an XlaExpression to a tensor on an XLA compilation device. +static void AssignExpressionToTensor(Tensor* tensor, + const XlaExpression& value) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK(!expression->handle().valid()); - return const_cast(expression); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast(expression) = value; } -// Retrieves the XlaOp from an input Tensor to an Op. This computation was -// constructed by an Op that executed previously and created the output Tensor -// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. 
-static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { - return CastExpressionFromTensor(tensor)->handle(); +const XlaExpression& XlaOpKernelContext::InputExpression(int index) { + return *CastExpressionFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(int index) { - return GetComputationFromTensor(context_->input(index)); +const XlaExpression& XlaOpKernelContext::InputExpression( + absl::string_view name) { + return *CastExpressionFromTensor(GetInputTensorByName(name)); } -const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { - return GetComputationFromTensor(GetInputTensorByName(name)); +xla::XlaOp XlaOpKernelContext::Input(int index) { + return InputExpression(index).AsXlaOp(builder()); +} + +xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) { + return InputExpression(name).AsXlaOp(builder()); } TensorShape XlaOpKernelContext::InputShape(int index) { @@ -125,77 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name, Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal) { - const Tensor& tensor = context_->input(index); - TensorShape new_shape(new_dims); - if (tensor.NumElements() != new_shape.num_elements()) { - return errors::InvalidArgument( - context_->op_kernel().name(), " input ", index, " has shape ", - tensor.shape().DebugString(), - " but was asked to be reshaped to incompatible shape ", - new_shape.DebugString()); - } - const XlaExpression* expression = CastExpressionFromTensor(tensor); - - auto copy_tensor_to_literal = [](const Tensor& tensor, - xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); - - *literal = xla::Literal(literal_shape); - - // memcpy over the payload ... - // TODO(phawkins): handle string types. 
- size_t total_bytes = tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&tensor); - memcpy(dst_ptr, src_ptr, total_bytes); - } - return Status::OK(); - }; - - // If the tensor has a known constant value, there is no need to invoke XLA. - if (expression->has_constant_value()) { - Tensor temp(tensor.dtype()); - if (!temp.CopyFrom(expression->constant_value(), new_shape)) { - // This should never happen. The constant should have a shape compatible - // with the enclosing Tensor. - return errors::Internal("Incompatible shapes in ConstantInputReshaped."); - } - - return copy_tensor_to_literal(temp, constant_literal); - } - - // Make sure we treat zero-element tensors as constant. - if (new_shape.num_elements() == 0) { - Tensor temp(tensor.dtype(), new_shape); - - return copy_tensor_to_literal(temp, constant_literal); - } - - xla::XlaOp handle = expression->handle(); - if (new_shape != tensor.shape()) { - // Reshape the handle to the desired shape. - handle = xla::Reshape(handle, new_shape.dim_sizes()); - } - - // The XLA layout is specified minor to major, and TensorFlow's minor - // dimension is the last one. 
- std::vector layout_indices(new_shape.dims()); - std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); - xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); - - xla::StatusOr is_constant = builder()->IsConstant(handle); - if (!is_constant.ok()) { - Status status = is_constant.status(); + XlaExpression e = InputExpression(index); + xla::StatusOr> constant_or_status = + e.ResolveConstant(compiler()->client()); + if (!constant_or_status.ok()) { + Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", context_->op_kernel().type_string(), " operator as a compile-time constant."); return status; } - - if (!is_constant.ValueOrDie()) { + absl::optional constant = constant_or_status.ValueOrDie(); + if (!constant.has_value()) { return errors::InvalidArgument( "Input ", index, " to ", context_->op_kernel().type_string(), " operator must be a compile-time constant.\n" @@ -208,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped( "stateful operation such as a random number generator."); } - // Ask the XLA compiler to evaluate the data handle to a literal. 
- xla::StatusOr constant_graph = - builder()->BuildConstantSubGraph(handle); - if (!constant_graph.ok()) { - return errors::Internal( - "Error getting a compile-time constant graph for ", - context_->op_kernel().name(), " input ", index, - ".\nError: ", constant_graph.status().error_message()); - } - xla::StatusOr computed = compiler()->client()->ComputeConstant( - constant_graph.ValueOrDie(), &layout); - if (!computed.ok()) { - return errors::Internal("Error evaluating ", context_->op_kernel().name(), - " input ", index, - " as a compile-time constant.\nError: ", - computed.status().error_message()); + Tensor temp(constant->dtype()); + if (!temp.CopyFrom(*constant, TensorShape(new_dims))) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + constant->shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + TensorShape(new_dims).DebugString()); } - *constant_literal = std::move(computed).ValueOrDie(); + TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); return Status::OK(); } @@ -322,6 +259,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector( return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + absl::string_view name, std::vector* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -372,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(GetComputationFromTensor(input)); + handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); 
@@ -413,9 +359,12 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, XlaContext& xla_context = XlaContext::Get(ctx); TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, + xla::Shape representation_shape, xla_context.RepresentationShape(variable->shape(), variable->type())); - if (representation_shape == variable->shape()) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); + if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { *value = variable->value(); } else { *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); @@ -455,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, - Tensor** output) { - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - if (expected_output_dtype(index) == DT_VARIANT) { - // tensor_data() is not supported for variant Tensor (i.e., - // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the - // XlaExpression inside the Tensor's tensor_data() does not work for - // variant. Instead construct a uint8 tensor and store the expression in its - // value. - // TODO(jpienaar): This should be refactored to stop masquerading - // XlaExpressions as Tensors. 
- *output = new Tensor(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - context_->allocate_temp(DT_UINT8, tensor_shape, *output)); - context_->set_output(index, **output); - } else { - TensorShape tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); - TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); +void XlaOpKernelContext::SetOutputExpression(int index, + const XlaExpression& expression) { + Status status = [&] { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + Tensor* output = nullptr; + // Provides a special behavior for DT_VARIANT: a variant is treated as + // DT_UINT8 scalar as the type to allow mapping for variant to more generic + // types. + if (expression.dtype() == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in + // its value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, output)); + context_->set_output(index, *output); + } else { + TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); + TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); + } + AssignExpressionToTensor(output, expression); + return Status::OK(); + }(); + if (!status.ok()) { + SetStatus(status); } - return Status::OK(); } void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { - // Makes the host Tensor that will refer to the expression. 
- Tensor* output = nullptr; - auto shape_or = builder()->GetShape(handle); - if (!shape_or.ok()) { - SetStatus(shape_or.status()); - return; - } - - OP_REQUIRES_OK(context_, - allocate_output(index, shape_or.ValueOrDie(), &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); + SetOutputExpression( + index, + XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); } void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { - const TensorShape& shape = constant.shape(); - - xla::BorrowingLiteral literal; - OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); - - xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); - CHECK(handle.valid()); - - // Make the Tensor that will refer to the expression. - Tensor* output = nullptr; - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. 
- XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); - expression->set_constant_value(constant); -} - -void XlaOpKernelContext::SetInvalidOutput(int index) { - Tensor* output = nullptr; - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape({}), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::XlaOp handle; - expression->set_handle(handle); + SetOutputExpression(index, XlaExpression::Constant(constant)); } void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { - Tensor* output = nullptr; - // The shape of the output tensor is the shape of the resource itself - // (i.e., a scalar), not the shape of the resource's value. - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape(), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_resource(resource); + SetOutputExpression(index, XlaExpression::Resource(resource)); } Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { @@ -570,10 +482,13 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, xla_context.RepresentationShape(shape, type)); - if (shape != representation_shape) { - handle = xla::Reshape(handle, representation_shape.dim_sizes()); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { + handle = xla::Reshape(handle, + xla::AsInt64Slice(representation_shape.dimensions())); } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 
aa00a454968ad29495e34dc080e55b62bb0b5f7b..c06efa2c474c5ec3cb5d75d94ba15d4096faa085 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -88,9 +88,9 @@ class XlaOpKernelContext { // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::XlaOp& Input(int index); + xla::XlaOp Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(absl::string_view name); + xla::XlaOp Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -111,14 +111,6 @@ class XlaOpKernelContext { Status ConstantInput(int index, xla::Literal* constant_literal); Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); - // Evaluates input `index`, reshapes it to `new_shape` if new_shape != - // InputShape(index), and stores it in `*constant_literal`. If the input - // cannot be evaluated, e.g., because it depends on unbound parameters, - // returns a non-Ok status. If InputShape(index).num_elements() != - // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, absl::Span new_dims, - xla::Literal* constant_literal); - // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); Status ConstantInputAsIntScalar(absl::string_view name, int64* out); @@ -134,6 +126,8 @@ class XlaOpKernelContext { // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. Status ConstantInputReshapedToIntVector(int index, std::vector* out); + Status ConstantInputReshapedToIntVector(absl::string_view name, + std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. 
Status ConstantInputAsInt64Literal(int index, xla::Literal* out); @@ -148,6 +142,10 @@ class XlaOpKernelContext { Status ConstantInputList(absl::string_view name, std::vector* literals); + // Returns an XlaExpression describing the value of 'index'. + const XlaExpression& InputExpression(int index); + const XlaExpression& InputExpression(absl::string_view name); + // Outputs int num_outputs() const { return context_->num_outputs(); } @@ -165,9 +163,8 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output `index` to an invalid value. - // Any subsequent attempt to consume this output will cause an error. - void SetInvalidOutput(int index); + // Sets output `index` to the value described by `expression`. + void SetOutputExpression(int index, const XlaExpression& expression); // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } @@ -255,10 +252,13 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); - // Wraps OpKernelContext's allocate_output method while providing special - // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the - // type to allow mapping for variant to more generic types. - Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input + // cannot be evaluated, e.g., because it depends on unbound parameters, + // returns a non-Ok status. If InputShape(index).num_elements() != + // new_shape.num_elements(), returns an error status. 
+ Status ConstantInputReshaped(int index, absl::Span new_dims, + xla::Literal* constant_literal); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 9f00de708cc5aceb2c1e397663bc3bba8705bda4..dcd0e9c5c1f20c07c6d2b6fd7315a861817bc523 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -129,21 +130,27 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. static void* registration_init = [®istry]() { + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + mutex_lock lock(registry.mutex_); if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + cpu_global_jit + ? 
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally + : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; registration.compile_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = true; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; registration.compile_resource_ops = false; } return nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 45a40c0acc07805b422591fd7ea3fcb131db8471..0bdd4a1085445420a5147756daac4a54f4725f11 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -66,19 +66,26 @@ class XlaOpRegistry { public: typedef OpKernel* (*Factory)(OpKernelConstruction*); + enum class AutoclusteringPolicy { + // Enable autoclustering if the user requests it, e.g., via + // experimental_jit_scope. Does not autocluster if the JIT is enabled + // globally (e.g., via the OptimizerOptions in the TF session + // configuration.) + kIfExplicitlyRequested, + // Enable autoclustering if explicitly requested, or if the JIT is enabled + // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N. + kIfEnabledGlobally, + // Always try to autocluster ops placed on this device. + kAlways, + }; + // Describes how to compile operators assigned to a device. struct DeviceRegistration { // The name of the an XLA compilation device to use to compile code. string compilation_device_name; - // Do operators assigned to this device require compilation? - bool requires_compilation; - - // If !requires_compilation, should we try to JIT operators on this device - // when XLA JIT compilation is enabled globally via the SessionOptions? 
- // (It is still possible to explicitly mark operators to JIT compile, even - // if enable_jit_by_default is false.) - bool enable_jit_by_default; + // When should we autocluster operators assigned to this device? + AutoclusteringPolicy autoclustering_policy; // Enable compilation of operators that use DT_RESOURCE types? bool compile_resource_ops = false; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 63b09c8f02a60e91576544d13227d29f56d3e88c..a322eb9015e829fd468133f3de6c12aad7e4ff74 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -26,6 +26,19 @@ limitations under the License. namespace tensorflow { +/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) { + switch (kind) { + case XlaResource::kInvalid: + return "invalid"; + case XlaResource::kVariable: + return "variable"; + case XlaResource::kStack: + return "stack"; + case XlaResource::kTensorArray: + return "tensorarray"; + } +} + XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index aa9ce1b171f11ea0de4db0123098729c1c97f93a..857b9a928bb824656f637b2b1ca2fc02a1bef139 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -35,6 +36,7 @@ class XlaResource { kTensorArray, kStack, }; + static absl::string_view KindToString(Kind kind); XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d6b60c5f9916520ba7585824171aad1548610da6..91096cf1d043eb652756f77b7594780124260766 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -68,7 +68,7 @@ cc_library( visibility = [":friends"], deps = [ ":xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", ], ) @@ -735,6 +735,70 @@ tf_cc_test( ], ) +cc_library( + name = "parse_flags_from_env", + srcs = ["parse_flags_from_env.cc"], + hdrs = ["parse_flags_from_env.h"], + deps = + [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "parse_flags_from_env_test", + srcs = ["parse_flags_from_env_test.cc"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "debug_options_flags", + srcs = [ + "debug_options_flags.cc", + "debug_options_parsers.h", + ], + hdrs = ["debug_options_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "debug_options_parsers_test", + size = "small", + srcs = 
[ + "debug_options_parsers.h", + "debug_options_parsers_test.cc", + ], + deps = + [ + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 0cbe68d7efd9fe2ea46b312763437e1b8c986d25..42da0ebf4992884187bbe21701a44d8ba2fccd64 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -68,6 +68,7 @@ cc_library( deps = [ ":global_data", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", @@ -76,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -236,13 +236,13 @@ tf_cc_test( deps = [ ":xla_builder", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", 
"//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index f5f8d5c6b1fe265069992fe92acaa229647d4e8c..eef2844e0df6aaf509881535f41493673fbeeee5 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -210,11 +210,10 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const XlaComputation& computation, absl::Span arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteGraphRequest request; +StatusOr Client::Compile( + const XlaComputation& computation, absl::Span argument_shapes, + const ExecutionOptions* execution_options) { + CompileRequest request; *request.mutable_computation() = computation.proto(); if (execution_options == nullptr) { @@ -222,6 +221,34 @@ StatusOr> Client::Execute( } else { *request.mutable_execution_options() = *execution_options; } + if (request.execution_options().device_handles_size() > 1) { + return InvalidArgument( + "Compiling with multiple device handles is not supported. Use " + "'Execute' instead."); + } + + // The argument shapes affect how the computation is compiled. 
+ for (const auto& arg_shape : argument_shapes) { + *request.add_input_shape_with_layout() = arg_shape; + } + + CompileResponse response; + VLOG(1) << "making compile request: " << request.ShortDebugString(); + Status s = stub_->Compile(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + TF_RET_CHECK(response.has_handle()); + return response.handle(); +} + +StatusOr> Client::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { + ExecuteRequest request; + *request.mutable_handle() = handle; for (GlobalData* argument : arguments) { CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); @@ -229,7 +256,7 @@ StatusOr> Client::Execute( ExecuteResponse response; VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->ExecuteGraph(&request, &response); + Status s = stub_->Execute(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -238,15 +265,62 @@ StatusOr> Client::Execute( if (execution_profile != nullptr) { *execution_profile = response.profile(); + } + + return absl::make_unique(stub_, response.output()); +} + +StatusOr> Client::Execute( + const XlaComputation& computation, absl::Span arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + if (execution_options != nullptr && + execution_options->device_handles_size() > 1) { + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. 
+ for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); + } + } + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); + } + + // The argument shapes affect how the computation is compiled. + std::vector arg_shapes(arguments.size()); + for (int i = 0; i < arguments.size(); i++) { + TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); + } + + TF_ASSIGN_OR_RETURN(auto handle, + Compile(computation, arg_shapes, execution_options)); + + TF_ASSIGN_OR_RETURN(auto result, + Execute(handle, arguments, execution_profile)); + + if (execution_profile != nullptr) { if (VLOG_IS_ON(1)) { TF_ASSIGN_OR_RETURN( auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); + ExecutionStatsAsString(computation, *execution_profile)); VLOG(1) << execution_stats; } } - return absl::make_unique(stub_, response.output()); + return std::move(result); } StatusOr>> Client::ExecuteParallel( @@ -274,10 +348,11 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { + for (size_t i = 0; i < response.responses_size(); ++i) { outputs.push_back( absl::make_unique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { + if (i < computations.size() && + computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } } @@ -390,8 +465,7 @@ StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); + GetComputationStats(computation, GetDebugOptionsFromFlags())); int64 
total_flops = computation_stats.flop_count() + computation_stats.transcendental_count(); if (profile.compute_time_ns() > 0) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 6f4d33c469f1f885cfeef546e3981dc3417ef71f..d0ac4703c632e0e01d3c8911594b46fedf28930d 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -40,6 +40,31 @@ class Client { explicit Client(ServiceInterface* stub); virtual ~Client(); + // Compiles the computation with the given argument shapes and returns the + // handle to the compiled executable. The compiled executable is cached on the + // service, and the returned handle can be used for execution without + // re-compiling. + // * The shape and layout of the arguments being executed with will affect how + // the computation is compiled. If argument_shapes is empty, the parameters' + // shape and layout will be used in the compilation. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * execution_options.device_handles should be empty. If you need + // non-empty device handles, call 'Execute' instead. + StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions* execution_options = nullptr); + + // Executes the compiled executable for the given handle with the given + // arguments and returns the global data that was produced from the execution. + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global // data that was produced from the execution. 
// * If execution_options is not nullptr, these options are passed to the diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index d3d7edb42a38595bbf9fdb36e0dd946ae5df51f9..08a887a6e4660cb2528f0ec7244b7ccc540808d2 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -265,6 +265,22 @@ XlaOp Digamma(XlaOp input) { return result; } +// Implements Banker's rounding: numbers that are equidistant between two +// integers are rounded towards even. +XlaOp RoundToEven(XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); + + auto round_val = xla::Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + round_val + one, round_val); +} + // Trigonometric functions. // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index a6cafd42077367bf23ffa1f45eab31c01dc31b16..3f06d04b9ae98b3aa75e68cd07810b2b4c24d280 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -51,6 +51,10 @@ XlaOp Lgamma(XlaOp input); // Computes an approximation of the digamma function. XlaOp Digamma(XlaOp input); +// Rounds the given number to even when the number is equidistant between two +// integers. +XlaOp RoundToEven(XlaOp x); + // Trigonometric functions // Computes the arc cosine of 'x'. 
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 14c259a7fa2a47642663b65d2785e5bbdc040cfd..ae2ea225d1aadd7b3a794eabeca866c498f34760 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -136,5 +136,17 @@ XLA_TEST_F(MathTest, Digamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, RoundToEven) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-1.4, -1.5, -2.5, -0.5, 0, 0.5, 1.5, 2.5, 3.5, 4.5}); + RoundToEven(x); + + std::vector expected = {-1.0, -2.0, -2.0, -0.0, 0, + 0.0, 2.0, 2.0, 4.0, 4.0}; + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9..e49451ca9708ab506d11af5f9855db245674864c 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -60,8 +60,8 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. // - // The given ExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. Status ValidateExecutionOptions( const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); @@ -69,8 +69,8 @@ class LocalExecutable { // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. // - // The given ServiceExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). 
+ // The given ServiceExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const absl::Span arguments); @@ -114,8 +114,8 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // The given ExecutableBuildOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // environment variable. StatusOr> Compile( const XlaComputation& computation, const absl::Span argument_layouts, diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index f9c23b44810a52ae4dd40cc838e6cb575cb44445..0a587725d20507555382ef0657bdc08369a7fbac 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2305,6 +2305,19 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, }); } +XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension)); + instr.add_dimensions(dimension); + return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, + {operand}); + }); +} + StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(first_error_); @@ -3158,4 +3171,8 @@ XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { return builder->Iota(shape, iota_dimension); } +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) { + return operand.builder()->GetDimensionSize(operand, dimension); +} + } // namespace xla 
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 908a616b4ead8820b5df991c3bc0b2f6724087ef..68314a026eab0db3eaf321f0fa53c016d79882ba 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -933,6 +933,8 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); + XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, absl::Span operands = {}); @@ -1355,6 +1357,8 @@ class XlaBuilder { const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + + friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); }; // RAII-style object: sets the current sharding assignment in builder on @@ -2129,6 +2133,10 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& grad_output, float epsilon, int64 feature_index); +// Returns the size of the given dimension of the operand. The operand must be +// array shaped. +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + // Implementation details below this point. template diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index dfe5fd5eb23ca51d2a449106a21293405a3dab6f..8aa85c3cd63c9b0aeb55d2cebbb989b6432ac959 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -43,7 +43,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -54,7 +54,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -349,6 +349,15 @@ TEST_F(XlaBuilderTest, CollectivePermute) { EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); } +TEST_F(XlaBuilderTest, GetDimensionSize) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + GetDimensionSize(x, 1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc similarity index 96% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc rename to tensorflow/compiler/xla/debug_options_flags.cc index 
3ed3afcfcede20fbf5c7d4f004378817febeb4c7..033887d7c11bb530d70f0653f26c61bcbfe1e321 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -13,17 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace legacy_flags { - namespace { DebugOptions* flag_values; @@ -101,8 +99,8 @@ void AllocateFlags() { [](string comma_separated_values) { auto* extra_options_map = flag_values->mutable_xla_backend_extra_options(); - impl::parse_xla_backend_extra_options(extra_options_map, - comma_separated_values); + parse_xla_backend_extra_options(extra_options_map, + comma_separated_values); return true; }; @@ -111,8 +109,8 @@ void AllocateFlags() { [](string reduce_precision_option_value) { HloReducePrecisionOptions* option_proto = flag_values->add_hlo_reduce_precision_options(); - return impl::parse_xla_reduce_precision_option( - option_proto, reduce_precision_option_value); + return parse_xla_reduce_precision_option(option_proto, + reduce_precision_option_value); }; flag_objects = new std::vector({ @@ -353,5 +351,4 @@ xla::DebugOptions GetDebugOptionsFromFlags() { return *flag_values; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h 
similarity index 81% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.h rename to tensorflow/compiler/xla/debug_options_flags.h index b53157f59c61cf4e0850e006ad3656f4be63a936..60e59abc2a2e0f1cce3de1afc928f9fe36f75b33 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ #include @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Appends flag definitions for debug options to flag_list. void AppendDebugOptionsFlags(std::vector* flag_list); @@ -32,7 +31,6 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // first. xla::DebugOptions GetDebugOptionsFromFlags(); -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/debug_options_parsers.h similarity index 94% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h rename to tensorflow/compiler/xla/debug_options_parsers.h index ee7eb019c07cf898e48886955b18710146644cac..80aadfd5ece0e768afaf1842d2b6c5b11c288b55 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/debug_options_parsers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ #include #include "absl/strings/numbers.h" @@ -23,8 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla.pb.h" namespace xla { -namespace legacy_flags { -namespace impl { template void parse_xla_backend_extra_options(T* extra_options_map, @@ -140,8 +138,6 @@ inline bool parse_xla_reduce_precision_option( return true; } -} // namespace impl -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/debug_options_parsers_test.cc similarity index 88% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc rename to tensorflow/compiler/xla/debug_options_parsers_test.cc index 6f197aec53c7596e84437a03affa9118f22f5a1d..8003c3496d5df9be2ff8a99bc171972c8e090c43 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/debug_options_parsers_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" #include #include @@ -23,13 +23,12 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace xla { -namespace legacy_flags { // Test that the xla_backend_extra_options flag is parsed correctly. 
TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { std::unordered_map test_map; string test_string = "aa=bb,cc,dd=,ee=ff=gg"; - impl::parse_xla_backend_extra_options(&test_map, test_string); + parse_xla_backend_extra_options(&test_map, test_string); EXPECT_EQ(test_map.size(), 4); EXPECT_EQ(test_map.at("aa"), "bb"); EXPECT_EQ(test_map.at("cc"), ""); @@ -41,7 +40,7 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -56,7 +55,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot;"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -71,7 +70,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -84,7 +83,7 @@ TEST(DebugOptionsFlags, 
ParseXlaReducePrecisionOptionNoOpcodes) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -96,7 +95,6 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz"); } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { diff --git a/tensorflow/compiler/xla/execution_options_util.cc b/tensorflow/compiler/xla/execution_options_util.cc index e83ff7cddd675197c7f6d7018257edb4c25b6228..cf569863bbe1c92bdcafb133d49dcf5ae8890ffe 100644 --- a/tensorflow/compiler/xla/execution_options_util.cc +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" namespace xla { ExecutionOptions CreateDefaultExecutionOptions() { ExecutionOptions execution_options; - *(execution_options.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); + *(execution_options.mutable_debug_options()) = GetDebugOptionsFromFlags(); return execution_options; } diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md index 5376a04669d7c17a2fed8cdab46e21277049bf72..ded1e582b24c7a45acc6b61ba9c018fa2a1e7db7 100644 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -58,7 +58,7 @@ sess = tf.Session(config=config) > compiled for the CPU. JIT compilation for CPU operations must be done via > the manual method documented below. -#### Manual +#### Manual with experimental_jit_scope() JIT compilation can also be turned on manually for one or more operators. This is done by tagging the operators to compile with the attribute @@ -79,6 +79,16 @@ The `_XlaCompile` attribute is currently supported on a best-effort basis. If an operator cannot be compiled, TensorFlow will silently fall back to the normal implementation. +#### Manual with xla.compile() + +Unlike experimental_jit_scope(), which silently falls back to normal TensorFlow +when an operator cannot be compiled, xla.compile() returns an explicit error. +This is useful if you want more predictable behavior from XLA compilation. + +Please see +[xla.compile() tutorial Colab](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb) +for how to use it.
+ ### Placing operators on XLA devices Another way to run computations via XLA is to place an operator on a specific diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index a3cdfe19b2e3470bb903cce6cbc79d8d13cc8349..73a9db75f6bf090bba5c3534f14d8ebfa421b5bb 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1339,6 +1339,22 @@ the semantics for `tf.gather_nd`. index `X` in the gather indices array picks an entire row and the result is the concatenation of all these rows. +## GetDimensionSize + +See also +[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the size of the given dimension of the operand. The operand must be +array shaped. + + `GetDimensionSize(operand, dimension)` + +| Arguments | Type | Semantics | +| ----------- | ------- | --------------------------------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | +: : : dimension : + ## GetTupleElement See also diff --git a/tensorflow/compiler/xla/g3doc/overview.md b/tensorflow/compiler/xla/g3doc/overview.md index 6a172c3ae159974fb4a34ec422a9a96079b0814a..d3428b7276131e8f406f60cfea9a9346c5478433 100644 --- a/tensorflow/compiler/xla/g3doc/overview.md +++ b/tensorflow/compiler/xla/g3doc/overview.md @@ -4,11 +4,8 @@ -> Note: XLA is experimental and considered alpha. Most use cases will not -> see improvements in performance (speed or decreased memory usage). We have -> released XLA early so the Open Source Community can contribute to its -> development, as well as create a path for integration with hardware -> accelerators. +> Note: XLA is still under development. Some use cases will not +> see improvements in speed or decreased memory usage. 
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that optimizes TensorFlow computations. The results are improvements in diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a83e3f78598e7c0afaada43b8ae1ba71ad4839d6 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -0,0 +1,412 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "f4TSNCvpENrW" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "vamNSA0vEP-m" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": {}, + "colab_type": "code", + "id": "xD_ydfejEV7H" + }, + "outputs": [], + "source": [ + "#@title MIT License\n", + "#\n", + "# Copyright (c) 2017 François Chollet\n", + "#\n", + "# Permission is hereby granted, free of charge, to any person obtaining a\n", + "# copy of this software and associated documentation files (the \"Software\"),\n", + "# to deal in the Software without restriction, including without limitation\n", + "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n", + "# and/or sell copies of the Software, and to permit persons to whom the\n", + "# Software is furnished to do so, subject to the following conditions:\n", + "#\n", + "# The above copyright notice and this permission notice shall be included in\n", + "# all copies or substantial portions of the Software.\n", + "#\n", + "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", + "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", + "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n", + "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", + "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n", + "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n", + "# DEALINGS IN THE SOFTWARE." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1oSi4lHFt3z" + }, + "source": [ + "# Welcome to the `xla.compile()` tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b7noD9NjFRL-" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/xla/jit#turning_on_jit_compilation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "v9YbsuLZaBXy" + }, + "source": [ + "xla.compile() is a new experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/).\n", + "\n", + "Please run all code blocks in order." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "45kUPj5ZFrRa" + }, + "outputs": [], + "source": [ + "import tensorflow as tf" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9NMQFjroSMns" + }, + "source": [ + "Imports the XLA library, which includes the experimental xla.compile() API."
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-Uggy03rSGJm" + }, + "outputs": [], + "source": [ + "from tensorflow.contrib.compiler import xla" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GZVNiRmTDV-5" + }, + "source": [ + "Define some necessary constants and prepare MNIST dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "f37TSEGvGX4_" + }, + "outputs": [], + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "TiVXchblG5hK" + }, + "outputs": [], + "source": [ + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n", + "\n", + "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n", + "images, labels = iterator.get_next()\n", + "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", + "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "x_ZehpZP-SfS" + }, + "source": [ + "## Defines build_mnist_model function to construct model\n", + "\n", + "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", + "\n", + "When called, it returns two values. 
`y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ZbhJl_WvGa3g" + }, + "outputs": [], + "source": [ + "def build_mnist_model(x, y_):\n", + " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", + "\n", + " cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n", + " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", + "\n", + " return y, train_step" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7Jh3lyQHDfM9" + }, + "source": [ + "## Uses xla.compile with build_mnist_model function to enable XLA" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EtDwez_1gjzv" + }, + "source": [ + "Following code block wraps the model with xla.compile(), which allows the target function with provided inputs to be executed by XLA." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kYpCXCdRHNuN" + }, + "outputs": [], + "source": [ + "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4giQh62IrZGF" + }, + "source": [ + "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", + "\n", + "xla.compile does not return any\n", + "`tf.Operation` nodes that can be executed independently from the generated XLA ops. Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. 
This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n", + "\n", + "In pseudo-code, xla.compile's implementation looks as follows:\n", + "\n", + "---\n", + "```\n", + "# Ask Tensorflow to execute code in XLA-friendly manner\n", + "\n", + "y, train_step = build_mnist_model(images, labels)\n", + "with tf.control_dependencies([train_step]):\n", + " y = tf.identity(y)\n", + "\n", + "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n", + "```\n", + "---\n", + "\n", + "xla.compile() always returns a list of `tf.Tensor`s (even if there is only one element)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "TPGas4jjFLZl" + }, + "source": [ + "If you were to print the constructed graph now, you would see that it is not much different from a normal TensorFlow graph, and you would not be able to find the XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, TensorFlow triggers a series of graph rewrite passes that actually generate XLA ops, which compile and execute the computation when all inputs are ready."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EZD1m_n1DxAF" + }, + "source": [ + "## Trains and tests the model" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "qe28bAHNHUG2" + }, + "outputs": [], + "source": [ + "# Creates session and initialize all variables.\n", + "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", + "sess = tf.Session()\n", + "sess.run(tf.global_variables_initializer())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qgsKmz3n2UiW" + }, + "source": [ + "Following code block trains model.\n", + "\n", + "Note that evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_GxF6jTRHVuA" + }, + "outputs": [], + "source": [ + "# Feeds training dataset\n", + "sess.run(iterator.make_initializer(train_ds))\n", + "\n", + "# Runs TRAIN_STEPS steps\n", + "for i in range(TRAIN_STEPS):\n", + " sess.run(y)\n", + "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "dHlQlRSRHXD1" + }, + "outputs": [], + "source": [ + "# Tests trained model\n", + "\n", + "# Feeds testing dataset\n", + "sess.run(iterator.make_initializer(test_ds))\n", + "\n", + "# Calculates accuracy\n", + "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + "print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ynJQIuzjHYOb" + }, + "outputs": [], + "source": [ + "# Cleans up session\n", + 
"sess.close()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "xla.compile() Tutorial", + "provenance": [], + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 93522d2ca87a7eba8d3c7533785c54e63ce507b0..fa94d0afb4c9280b8f8fa9642c1b0ab7285ee6f3 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -24,8 +24,7 @@ limitations under the License. namespace xla { namespace { -void SetMinorToMajorLayout(Shape* shape, - std::initializer_list dimensions) { +void SetMinorToMajorLayout(Shape* shape, std::vector dimensions) { shape->mutable_layout()->clear_minor_to_major(); for (auto dimension : dimensions) { shape->mutable_layout()->add_minor_to_major(dimension); @@ -122,7 +121,7 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { std::vector linear_indexes = {0, 1439999999, 1145567336, 43883404, 617295214, 1117613654}; - std::vector> minor_to_major_orders; + std::vector> minor_to_major_orders; minor_to_major_orders.push_back({6, 5, 4, 3, 2, 1, 0}); minor_to_major_orders.push_back({0, 1, 2, 3, 4, 5, 6}); minor_to_major_orders.push_back({4, 5, 1, 2, 6, 0, 3}); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD deleted file mode 100644 index 3e79129aafd234e5eab05d205f2017b54057795e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ /dev/null @@ -1,82 +0,0 @@ -# Legacy command-line flags for the XLA libraries. - -# Please do not add more flags to this package. - -# The XLA libraries were written in an environment that allowed command-line -# flags to be scattered freely throughout the libraries. 
This model, while -# initially convenient, leads to a proliferation in unused command-line flags -# in tests and binaries, and serious problems in servers, where one might wish -# parameters to be different in independent RPC calls to the same routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. - -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "parse_flags_from_env", - srcs = ["parse_flags_from_env.cc"], - hdrs = ["parse_flags_from_env.h"], - deps = - [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "parse_flags_from_env_test", - srcs = ["parse_flags_from_env_test.cc"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "debug_options_flags", - srcs = [ - "debug_options_flags.cc", - "debug_options_parsers.h", - ], - hdrs = ["debug_options_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "debug_options_parsers_test", - size = "small", - srcs = [ - "debug_options_parsers.h", - "debug_options_parsers_test.cc", - ], - deps = - [ - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) diff --git 
a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 80dfdb83c35183ef2f2f80b7cd13589c14e4a50e..cb00a0ab16df851ccbd4bba960b92ea83157867d 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -1434,10 +1434,14 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case U8: return EqualElementsInternal(other, &multi_index); + case S16: + return EqualElementsInternal(other, &multi_index); case S32: return EqualElementsInternal(other, &multi_index); case S64: return EqualElementsInternal(other, &multi_index); + case U16: + return EqualElementsInternal(other, &multi_index); case U32: return EqualElementsInternal(other, &multi_index); case U64: @@ -1506,6 +1510,11 @@ bool LiteralBase::IsAll(int8 value) const { return AllElementsEqualValue(piece.data(), value); } return false; + case U16: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; case U32: if (value >= 0) { return AllElementsEqualValue(piece.data(), value); @@ -1518,6 +1527,8 @@ bool LiteralBase::IsAll(int8 value) const { return false; case S8: return AllElementsEqualValue(piece.data(), value); + case S16: + return AllElementsEqualValue(piece.data(), value); case S32: return AllElementsEqualValue(piece.data(), value); case S64: @@ -1739,12 +1750,16 @@ bool LiteralBase::IsZero(absl::Span indices) const { switch (shape().element_type()) { case U8: return Get(indices) == 0; + case U16: + return Get(indices) == 0; case U32: return Get(indices) == 0; case U64: return Get(indices) == 0; case S8: return Get(indices) == 0; + case S16: + return Get(indices) == 0; case S32: return Get(indices) == 0; case S64: @@ -1802,6 +1817,20 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case U16: + *proto->mutable_u16s() = string( + 
reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_u16s()); + } + break; + case S16: + *proto->mutable_s16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_s16s()); + } + break; case F16: *proto->mutable_f16s() = string( reinterpret_cast(data().data()), size_bytes()); @@ -1916,6 +1945,22 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case S16: { + const string& s(proto.s16s()); + TF_RET_CHECK(data().size() * sizeof(int16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case U16: { + const string& s(proto.u16s()); + TF_RET_CHECK(data().size() * sizeof(uint16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; case F16: { const string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 9d34d9d504156c4b9e645ccfa7cdbd346e51390b..b044f0ad73f13a0599e77f1f43888bc974e31f73 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -141,8 +141,10 @@ int64 RecursiveElementCount(const Shape& shape) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else { + } else if (ShapeUtil::IsArray(shape)) { return ShapeUtil::ElementsIn(shape); + } else { + return 0; } } diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 3511760ac1cad15941910b8fa74f0b7f36844e92..8cec37897a94472d61d2346cf4cab03c45033800 100644 --- 
a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -1394,6 +1394,28 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { EXPECT_EQ(h1, r[3]); } +TEST_F(LiteralUtilTest, CopyFromProto_u16) { + uint16 u1(0xabcd); + uint16 u2(0x1234); + + const unsigned char uint16_vals[8] = {0xcd, 0xab, 0x34, 0x12, + 0x34, 0x12, 0xcd, 0xab}; + LiteralProto p; + p.mutable_shape()->set_element_type(U16); + p.mutable_shape()->clear_dimensions(); + p.mutable_shape()->add_dimensions(4); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + p.clear_u16s(); + p.set_u16s(uint16_vals, 8); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); + ASSERT_EQ(4, r.size()); + EXPECT_EQ(u1, r[0]); + EXPECT_EQ(u2, r[1]); + EXPECT_EQ(u2, r[2]); + EXPECT_EQ(u1, r[3]); +} + TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc similarity index 98% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc rename to tensorflow/compiler/xla/parse_flags_from_env.cc index 2a4e49b05aa0d1eed2197095694cfc6aa8814983..40481331b6992103e10e3fe635a030d3bdffebc9 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -31,7 +31,6 @@ limitations under the License. 
#include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried static const char kWS[] = " \t\r\n"; // whitespace @@ -202,5 +201,4 @@ void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { *pargv = &env_argv->argv; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h b/tensorflow/compiler/xla/parse_flags_from_env.h similarity index 90% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h rename to tensorflow/compiler/xla/parse_flags_from_env.h index b54482ad2ba2224c781861341a80ceb878ffd343..fe86ee687f8482aaffc2ebe04a723d9a22f2cce6 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h +++ b/tensorflow/compiler/xla/parse_flags_from_env.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ +#define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ // This module exports ParseFlagsFromEnv(), which allows other modules to parse // flags from the environtment variable TF_XLA_FLAGS, or (if the first @@ -50,7 +50,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet // unrecognized flags passed in from the environment, and return its @@ -60,7 +59,6 @@ bool ParseFlagsFromEnv(const std::vector& flag_list); // Used only for testing. Not to be used by clients. 
void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv); -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ +#endif // TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc similarity index 96% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc rename to tensorflow/compiler/xla/parse_flags_from_env_test.cc index 138c0c852e2bb0527d171f25b4d96cedc5671516..edd6538402d6ceee292ca6a265f490be9709d3ae 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include #include @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Test that XLA flags can be set from the environment. // Failure messages are accompanied by the text in msg[]. @@ -159,12 +158,11 @@ TEST(ParseFlagsFromEnv, EnvAndFlag) { } } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { // Save name of binary so that it may invoke itself. 
- xla::legacy_flags::binary_name = argv[0]; + xla::binary_name = argv[0]; bool recursing = false; xla::int32 int_flag = 1; const std::vector flag_list = { @@ -173,7 +171,7 @@ int main(int argc, char* argv[]) { tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list); + bool parse_ok = xla::ParseFlagsFromEnv(flag_list); if (!parse_ok) { LOG(QFATAL) << "can't parse from environment\n" << usage; } diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 21685c4a5b90f76440e4cf10cce004b6cf925cc8..63ac1c6649210cbae9e238a74e0a45fb8ee4da63 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") py_library( name = "xla_client", @@ -66,6 +67,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt/cc:xrt_ops", @@ -81,6 +83,7 @@ tf_py_wrap_cc( srcs = ["xla.i"], swig_includes = [ "local_computation_builder.i", + "//tensorflow/python:platform/base.i", ], deps = [ ":local_computation_builder", @@ -89,5 +92,7 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ], + ] + if_cuda_is_configured([ + "//tensorflow/compiler/xla/service:gpu_plugin", + ]), ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc 
b/tensorflow/compiler/xla/python/local_computation_builder.cc index b1fae826ab1903fb73541a7ae32b5cc57b3b92a7..4d2a37cfac3e0e89d189f168031e5db44ca5d410 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -56,6 +57,12 @@ tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; +string* GetPlatformNameString() { + static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = + new string("Host"); + return platform_name_string; +} + Status InitializeReplicaCount(int replica_count) { if (replica_count < 1) { return InvalidArgument("Replica count must be >= 1; got %d.", @@ -72,17 +79,33 @@ Status InitializeReplicaCount(int replica_count) { return Status::OK(); } +Status InitializePlatformName(const string& platform_name) { + string* g_platform_name = GetPlatformNameString(); + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return FailedPrecondition( + "Attempted to set the platform name to %s, but a local XLA service was " + "previously created with a platform name of %s.", + platform_name, *g_platform_name); + } + TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); + *g_platform_name = platform_name; + return Status::OK(); +} + int GetReplicaCount() { tensorflow::mutex_lock lock(g_local_client_mutex); return g_replica_count; } LocalClient* GetOrCreateLocalClient() { + 
string* platform_name = GetPlatformNameString(); tensorflow::mutex_lock lock(g_local_client_mutex); if (g_local_client != nullptr) { return g_local_client; } LocalClientOptions options; + options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); options.set_number_of_replicas(g_replica_count); g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); CHECK(g_local_client != nullptr); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 82f84ddb35bd4455fd3607509c6329457cca47f3..9e617c48bdc5ae4b37c1a1db9a1876bb4c0a6f0d 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -39,6 +39,12 @@ namespace swig { // returned. Status InitializeReplicaCount(int replica_count); +// Initializes the platform name that XLA will be initialized with (when +// first obtaining a handle to the local XLA service). If this is called after +// the handle to the local XLA service has been established, then an error is +// returned. +Status InitializePlatformName(const string& platform_name); + // Returns the replica count that is currently set, regardless of whether the // local XLA service has been instantiated yet or not. 
int GetReplicaCount(); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index c13d00d2530c7e9321d483a70e4a12361159362d..feabfdb889ca055550c5d1e1c05ca47c1b0bd166 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -977,6 +977,7 @@ tensorflow::ImportNumpy(); %unignore xla; %unignore xla::swig; %unignore xla::swig::InitializeReplicaCount; +%unignore xla::swig::InitializePlatformName; %unignore xla::swig::GetReplicaCount; %unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::TransferToInfeedLocalReplica; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 07e0e093255b2baf3412852821fe62fa060f6cad..92b0685dbba195405d78867776fe43b5f6c60f4c 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1371,6 +1371,18 @@ def initialize_replica_count(replica_count): c_api.InitializeReplicaCount(replica_count) +def initialize_platform_name(platform_name): + """Initializes the desired platform name to use on XLA service init. + + Args: + platform_name: string name of platform. + + Raises: + A runtime exception if the XLA service has already been initialized. + """ + c_api.InitializePlatformName(platform_name) + + def get_replica_count(): """Returns the current replica count used for the XLA service. 
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 4e1435fa30a24c320ddbedb84d37b369a3158a54..d8123a6de28ca532819ece4a75cd0b725f8c1bbd 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,11 +47,18 @@ namespace xla { }); } -::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, + const CompileRequest* arg, + CompileResponse* result) { return DelegateRPC( - [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); + [this, arg, result]() { return service_->Compile(arg, result); }); +} + +::grpc::Status GRPCService::Execute(::grpc::ServerContext* /*context*/, + const ExecuteRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Execute(arg, result); }); } ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index ca1b09b648013ad45d806040c5ddcf11d9e5604e..3e586b288a56a22573d0c3b9ae7b2f25fdbf851a 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,9 +39,13 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + ::grpc::Status Compile(::grpc::ServerContext* context, + const CompileRequest* arg, + CompileResponse* result) override; + + ::grpc::Status Execute(::grpc::ServerContext* context, + const ExecuteRequest* arg, + ExecuteResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const 
WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 7b8ab158e1396d7087a407be180ab44d2e16e121..66abf66cfd6c2f753c5507aa373452ac880e9a29 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,10 +62,17 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::Compile(const CompileRequest* request, + CompileResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteGraph(context, *request, response); + return grpc_stub_->Compile(context, *request, response); + }); +} + +Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Execute(context, *request, response); }); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 8dfcb761387d608abbb1f62974f49b976a7ff7ff..f02b401399f3e895153f0b08e325bc9c2c2336ec 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,8 +43,11 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status Compile(const CompileRequest* request, + CompileResponse* response) override; + + Status Execute(const ExecuteRequest* request, + ExecuteResponse* response) override; Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index 
551ae895e05586daec0ffcd425f4950f76bdd50d..e4f332cda22cc5b889bf73f06913b96d6091dc81 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -128,11 +128,14 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. + // Compiles the provided computation into executable. Returns the handle of + // the executable. + rpc Compile(CompileRequest) returns (CompileResponse) {} + + // Invokes the provided executable with the provided global data passed as + // immutable arguments. The request contains the handle to the executable. // Returns global data output and execution timing. - rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) { - } + rpc Execute(ExecuteRequest) returns (ExecuteResponse) {} // Invokes the provided list of computations in parallel with the provided // global data for each computation. 
Returns a list of global data output and diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index cd8c20d43ea1623b908895ef35a34d92f44f5cc3..19b5c1ca25debf80c7e712854b47384937697d3d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,7 +87,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -124,7 +123,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -158,12 +156,12 @@ tf_cc_test( ":bfloat16_propagation", ":bfloat16_support", ":hlo", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -281,7 +279,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -323,7 +321,6 @@ cc_library( ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", - 
":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal", @@ -365,7 +362,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -402,6 +398,7 @@ cc_library( srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], deps = [ + ":hlo", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", @@ -420,7 +417,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -466,7 +462,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -519,7 +514,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -568,7 +562,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -591,7 +584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -603,11 +595,11 @@ cc_library( hdrs = ["platform_util.h"], deps = [ ":compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -647,6 +639,7 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", + ":compilation_cache", ":compiler", ":computation_layout", ":device_memory_allocator", @@ -662,6 +655,7 @@ cc_library( ":source_map_util", ":stream_pool", ":transfer_manager", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -673,7 +667,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", @@ -730,12 +723,12 @@ cc_library( ":computation_layout", ":platform_util", ":service", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -811,6 +804,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", + "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/memory", ], @@ -833,6 +827,7 @@ cc_library( ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", @@ -840,7 +835,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -1086,7 +1080,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1103,6 +1096,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_reachability", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1168,7 +1162,6 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1388,6 +1381,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":hlo_reachability", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", @@ -1428,7 +1422,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -1504,7 +1497,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1556,7 +1548,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1593,7 +1584,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1643,7 +1633,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1695,6 +1685,19 @@ cc_library( ], ) +tf_cc_test( + name = "while_loop_analysis_test", + srcs = ["while_loop_analysis_test.cc"], + deps = [ + ":while_loop_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + 
"//tensorflow/core:test", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1703,9 +1706,9 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", + ":hlo_query", ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -1717,10 +1720,12 @@ tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ + ":hlo", + ":hlo_dce", ":hlo_matchers", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -1751,7 +1756,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1779,7 +1784,7 @@ tf_cc_test( ":implicit_broadcast_remover", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1824,7 +1829,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1858,7 +1862,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -2264,7 +2268,6 @@ 
tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2327,13 +2330,26 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], ) +cc_library( + name = "compilation_cache", + srcs = ["compilation_cache.cc"], + hdrs = ["compilation_cache.h"], + deps = [ + ":executable", + ":hlo_module_config", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "layout_assignment", srcs = [ @@ -2403,14 +2419,13 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2528,7 +2543,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2595,7 +2609,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2657,7 +2670,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2698,7 +2711,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2737,7 +2749,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2809,10 +2820,9 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -3000,7 +3010,6 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -3279,6 +3288,8 @@ cc_library( ":tuple_util", "//tensorflow/compiler/xla:literal_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", ], ) @@ -3305,6 +3316,7 @@ cc_library( ":hlo", ":hlo_pass", ":tuple_util", + ":while_loop_analysis", ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -3324,7 +3336,7 @@ tf_cc_test( ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3354,7 +3366,7 @@ tf_cc_test( ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3415,7 +3427,7 @@ tf_cc_test( ":indexed_array_analysis", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", ], @@ -3512,7 +3524,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc 
b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 85fc42f74756458ee677e8b53448ceb02f08e834..89e62bd2f0dc02d2d0947ae47e3bb0c9955f103e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include #include #include #include @@ -107,6 +108,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override; + Status HandleAnd(HloInstruction* logical_and) override; + Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBitcastConvert(HloInstruction* bitcast) override; @@ -141,6 +144,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMultiply(HloInstruction* multiply) override; + Status HandleNegate(HloInstruction* negate) override; + + Status HandleNot(HloInstruction* logical_not) override; + + Status HandleOr(HloInstruction* logical_or) override; + Status HandlePad(HloInstruction* pad) override; Status HandlePower(HloInstruction* power) override; @@ -306,9 +315,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Tries to use a kDot in place of the given convolution. StatusOr SimplifyConvToDot(HloInstruction* convolution); - // Tries to simplify a slice(pad(...)) where the result of the slice is a - // scalar. - StatusOr TrySimplifySliceOfPad(HloInstruction* slice); + // Tries to simplify a slice where the result of the slice is a scalar. + StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + + // Tries to convert slice(reshape(X)) into reshape(slice(X)) + StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. 
@@ -423,6 +434,43 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); + // Simplify logical and + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A && True => A + VLOG(10) << "trying transform [A && True => A]: " + << logical_and->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + // True && A => A + VLOG(10) << "trying transform [True && A => A]: " + << logical_and->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // A && False => False + VLOG(10) << "trying transform [A && False => False]: " + << logical_and->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // False && A => False + VLOG(10) << "trying transform [False && A => False]: " + << logical_and->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. 
HloInstruction* op; @@ -1229,6 +1277,64 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { + // negate(negate(x)) => x + HloInstruction* x; + if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) && + ReplaceInstructionIfSameShape(negate, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) { + // not(not(x)) => x + HloInstruction* x; + if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) && + ReplaceInstructionIfSameShape(logical_not, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs)))); + + // Simplify logical or + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A || True => True + VLOG(10) << "trying transform [A || True => True]: " + << logical_or->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + // True || A => True + VLOG(10) << "trying transform [True || A => True]: " + << logical_or->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // A || False => A + VLOG(10) << "trying transform [A || False => A]: " + << logical_or->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // False || A => A + VLOG(10) << "trying transform [False || A => A]: " + << logical_or->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform 
[ln(exp(A)) => A]: " << log->ToString(); @@ -1826,60 +1932,160 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor::TrySimplifySliceOfPad( +StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( HloInstruction* slice) { // Only try to do this for effective scalars. We could do the same for slicing // out larger pieces of padding (replacing with a broadcast of the padding // value), but this is probably not worth it. - if (!ShapeUtil::IsEffectiveScalar(slice->shape()) || - slice->operand(0)->opcode() != HloOpcode::kPad) { + if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { return false; } - VLOG(10) << "Trying to simplify scalar slice of pad"; - // Check there's no internal padding. Again, we could handle that too, since - // everything is statically known, but it's not worth it. - auto pad = Cast(slice->mutable_operand(0)); - auto padding_config = pad->padding_config(); - int64 rank = padding_config.dimensions_size(); - if (HasInteriorPadding(padding_config)) { - VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; - return false; + if (slice->operand(0)->opcode() == HloOpcode::kPad) { + VLOG(10) << "Trying to simplify scalar slice of pad"; + // Check there's no internal padding. Again, we could handle that too, since + // everything is statically known, but it's not worth it. + auto pad = Cast(slice->mutable_operand(0)); + auto padding_config = pad->padding_config(); + int64 rank = padding_config.dimensions_size(); + if (HasInteriorPadding(padding_config)) { + VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; + return false; + } + + // Check whether the scalar we're slicing out falls into the padding. 
+ bool in_padding = [&]() { + for (int64 i = 0; i < rank; ++i) { + int64 start = slice->slice_starts(i); + int64 low = padding_config.dimensions(i).edge_padding_low(); + int64 data = pad->operand(0)->shape().dimensions(i); + if (start >= low && start < low + data) { + return false; + } + } + return true; + }(); + + if (in_padding) { + VLOG(10) << "Folding scalar slice of pad into padding value"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), + pad->mutable_padding_value()))); + return true; + } else { + // We already know the output of the slice is scalar. If the padded + // value is scalar, and it's not in the padding, then it's exactly the + // output value. + bool replaced = + ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + if (replaced) { + VLOG(10) << "Folding scalar slice of pad into padded value"; + } else { + VLOG(10) << "Not folding scalar slice of pad into padded value as they " + "have different shapes."; + } + return replaced; + } } - // Check whether the scalar we're slicing out falls into the padding. - bool in_padding = [&]() { - for (int64 i = 0; i < rank; ++i) { - int64 start = slice->slice_starts(i); - int64 low = padding_config.dimensions(i).edge_padding_low(); - int64 data = pad->operand(0)->shape().dimensions(i); - if (start >= low && start < low + data) { - return false; + if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { + VLOG(10) << "Trying to simplify scalar slice of concat"; + // Only do this for R1, there's no chance of this being useful otherwise. + if (ShapeUtil::Rank(slice->shape()) != 1) { + VLOG(10) << "Not folding, slice is not rank 1"; + return false; + } + HloConcatenateInstruction* concat = + Cast(slice->mutable_operand(0)); + int64 operand_start = 0; + int64 operand_num = 0; + // Weird loop structure to avoid annoying off-by-one errors. 
+ while (true) { + TF_RET_CHECK(operand_num < concat->operand_count()); + const HloInstruction* operand = concat->operand(operand_num); + int64 next_operand_start = operand_start + operand->shape().dimensions(0); + if (next_operand_start > slice->slice_starts(0)) { + break; } + operand_start = next_operand_start; + operand_num++; } - return true; - }(); - if (in_padding) { - VLOG(10) << "Folding scalar slice of pad into padding value"; - TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( - slice, HloInstruction::CreateReshape(slice->shape(), - pad->mutable_padding_value()))); - return true; - } else { - // We already know the output of the slice is scalar. If the padded - // value is scalar, and it's not in the padding, then it's exactly the - // output value. - bool replaced = - ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + bool replaced = ReplaceInstructionIfSameShape( + slice, concat->mutable_operand(operand_num)); if (replaced) { - VLOG(10) << "Folding scalar slice of pad into padded value"; + VLOG(10) << "Folding scalar slice of concat into concat operand"; } else { - VLOG(10) << "Not folding scalar slice of pad into padded value as they " - "have different shapes."; + VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateSlice( + slice->shape(), concat->mutable_operand(operand_num), + {slice->slice_starts(0) - operand_start}, + {slice->slice_starts(0) - operand_start + 1}, + slice->slice_strides()))); } - return replaced; + return true; } + + return false; +} + +bool IsUnstridedSlice(const HloInstruction* hlo) { + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); +} + +StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( + HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!IsUnstridedSlice(slice)) { + return false; + } + HloInstruction* reshape = 
slice->mutable_operand(0); + if (reshape->opcode() != HloOpcode::kReshape) { + return false; + } + HloInstruction* new_slice_operand = reshape->mutable_operand(0); + int64 slice_rank = ShapeUtil::Rank(slice->shape()); + std::vector sliced_dims; + for (int64 i = 0; i < slice_rank; ++i) { + if (slice->slice_starts(i) != 0 || + slice->slice_limits(i) != reshape->shape().dimensions(i)) { + sliced_dims.push_back(i); + } + } + + if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && + slice->slice_starts(0) == 0) { + const Shape& new_slice_shape = new_slice_operand->shape(); + const int64 rank = ShapeUtil::Rank(new_slice_shape); + std::vector new_slice_starts(rank, 0); + std::vector new_slice_stides(rank, 1); + std::vector new_slice_limits(new_slice_shape.dimensions().begin(), + new_slice_shape.dimensions().end()); + int64 slice_elements = ShapeUtil::ElementsIn(slice->shape()); + for (int64 i = rank - 1; i >= 0; --i) { + if (slice_elements >= new_slice_limits[i]) { + if (slice_elements % new_slice_limits[i] != 0) { + return false; + } + slice_elements /= new_slice_limits[i]; + } else { + new_slice_limits[i] = slice_elements; + slice_elements = 1; + } + } + HloInstruction* new_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(new_slice_shape.element_type(), + new_slice_limits), + new_slice_operand, new_slice_starts, new_slice_limits, + new_slice_stides)); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); + return true; + } + return false; } Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { @@ -1888,12 +2094,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return Status::OK(); } - auto is_unstrided_slice = [](const HloInstruction* hlo) { - return absl::c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); - }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && - is_unstrided_slice(slice) && 
is_unstrided_slice(slice->operand(0))) { + IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) { HloInstruction* operand_slice = slice->mutable_operand(0); std::vector new_slice_starts = slice->slice_starts(); std::vector new_slice_limits = slice->slice_limits(); @@ -1907,11 +2109,15 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { new_slice_starts, new_slice_limits, slice->slice_strides())); } - TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifySliceOfPad(slice)); + TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); if (replaced) { return Status::OK(); } + TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + if (replaced) { + return Status::OK(); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 7b3e957fbcf9f4628c4aeb0c323d50d3ed36a4f2..e4c4da1b0e7aef0e3476e4d232e410da25794e13 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -33,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -54,10 +53,11 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase {}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -67,18 +67,19 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { + auto m = CreateNewVerifiedModule(); Shape r0s32 = ShapeUtil::MakeShape(S32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -88,12 +89,12 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero)); - auto computation = 
module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), zero); } @@ -166,6 +167,7 @@ TEST_F(AlgebraicSimplifierTest, SelectIdentical) { // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. HloInstruction* zero = builder.AddInstruction( @@ -180,7 +182,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); HloInstruction* param = builder.AddInstruction( @@ -193,17 +195,18 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, dims1, add_computation)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reduce(param, zero)); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); 
} // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -213,18 +216,19 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Constant())); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. 
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -239,17 +243,18 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -261,17 +266,18 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder 
builder(TestName()); // Create add computation. HloComputation* add_computation = nullptr; @@ -284,7 +290,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); HloInstruction* param0 = builder.AddInstruction( @@ -297,17 +303,18 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateBroadcast(r2f32, zero, {}))}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMap); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -319,64 +326,68 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14, 3.14, 4}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); - auto computation = 
module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); } // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -386,18 +397,19 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A - Const is canonicalized to A + (-Const). 
TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -407,18 +419,19 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); } // Test that (A/B)/C is simplified to A/(B*C). TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -432,14 +445,14 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Divide(param0, param1), param2)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Multiply(param1, param2))); @@ -447,6 +460,7 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { // Test that A/(B/C) is 
simplified to (A*C)/B. TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -460,14 +474,14 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Divide(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Multiply(param0, param2), param1)); @@ -475,6 +489,7 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -492,7 +507,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -500,7 +515,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -509,6 +524,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { // Test that A/exp(B) is simplified to A*exp(-B). 
TEST_F(AlgebraicSimplifierTest, DivOfExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -520,14 +536,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Exp(op::Negate(param1)))); @@ -535,6 +551,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfPower) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -548,14 +565,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -564,6 +581,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -577,14 +595,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -592,6 +610,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { // A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -602,11 +621,11 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Constant())); @@ -614,6 +633,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { // pow(pow(A, X), Y) => pow(A, X*Y) 
TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -627,10 +647,10 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Power(base, op::Multiply(exp1, exp2))); } @@ -638,6 +658,7 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + auto m = CreateNewVerifiedModule(); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -651,14 +672,15 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, inner_power, exp2)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } // Test that A/1 is simplified to A for a scalar. 
TEST_F(AlgebraicSimplifierTest, DivOneScalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -668,18 +690,19 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A/1 is simplified to A for an array. TEST_F(AlgebraicSimplifierTest, DivOneArray) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -689,18 +712,19 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that complex(real(c), imag(c)) is simplified to c. 
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); HloComputation::Builder builder(TestName()); @@ -713,18 +737,19 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that real(complex(r,i)) is simplified to r. TEST_F(AlgebraicSimplifierTest, RealOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -737,18 +762,19 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that imag(complex(r,i)) is simplified to i. 
TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -761,18 +787,19 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } // Test that get_element(make_tuple({A,B}),1) is simplified to B TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -788,18 +815,19 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param1, param2)); } // Test that exp(A)/exp(B) is simplified to exp(A-B) TEST_F(AlgebraicSimplifierTest, ExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 
= ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -813,14 +841,14 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Subtract(param0, param1))); @@ -828,6 +856,7 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { // Test that exp(A)*exp(B) is simplified to exp(A+B) TEST_F(AlgebraicSimplifierTest, ExpMul) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -841,14 +870,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Add(param0, param1))); @@ -856,6 +885,7 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { // Test that pow(exp(A), B) is simplified to exp(A*B) TEST_F(AlgebraicSimplifierTest, PowExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = 
ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -867,14 +897,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(op::Exp(param0), param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Multiply(param0, param1))); @@ -882,6 +912,7 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { // Test that ln(pow(A, B)) is simplified to ln(A)*B TEST_F(AlgebraicSimplifierTest, LnPow) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -893,14 +924,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Power(param0, param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Log(param0), param1)); @@ -908,6 +939,7 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { // Test that ln(exp(A)) is simplified to A TEST_F(AlgebraicSimplifierTest, LnExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 
HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -917,19 +949,20 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that ln(exp(A)/exp(B)) is simplified to A-B TEST_F(AlgebraicSimplifierTest, LnExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -945,14 +978,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } @@ -960,6 +993,7 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { // Test that pow(A, 0) where A is a scalar is simplified to the scalar // constant 1. 
TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -969,13 +1003,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -984,6 +1018,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
TEST_F(AlgebraicSimplifierTest, Pow0Vector) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {42}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -993,13 +1028,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast()); @@ -1012,6 +1047,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { // Test that pow(A, 1) is simplified to A. TEST_F(AlgebraicSimplifierTest, Pow1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1021,19 +1057,20 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that pow(A, 2) is simplified to A*A. 
TEST_F(AlgebraicSimplifierTest, Pow2) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1043,19 +1080,20 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } // Test that pow(A, -1) is simplified to 1/A. TEST_F(AlgebraicSimplifierTest, PowNegative1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1065,13 +1103,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); @@ -1081,6 +1119,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { } TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto m = CreateNewVerifiedModule(); auto builder = 
HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); @@ -1113,17 +1152,18 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Convolution(lhs, rhs)); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1148,24 +1188,25 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - 
EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::ReduceWindow(param, op::Constant())); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1182,17 +1223,18 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); - module().AddEntryComputation(builder.Build()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + m->AddEntryComputation(builder.Build()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Pad(param, op::Constant())); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Broadcast(op::Constant())); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -1206,39 +1248,41 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - module().AddEntryComputation(std::move(computation)); + m->AddEntryComputation(std::move(computation)); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), 
op::Reshape(op::Broadcast(op::Reshape(op)))); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), op); + EXPECT_THAT(m->entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } // Test that copies are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1246,18 +1290,19 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1268,24 +1313,25 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param)); AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. 
EXPECT_THAT(computation->root_instruction(), op::Copy(param)); AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } // Test that unary concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1293,19 +1339,20 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } // Test that empty operands of concatenates are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1322,7 +1369,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -1330,7 +1377,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0, param0, param1)); @@ -1338,6 +1385,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { // Test that reduce of concat is simplified. 
TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r3f32 = ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength}); @@ -1363,7 +1411,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength}); @@ -1373,11 +1421,11 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { builder.AddInstruction(HloInstruction::CreateReduce( reduce_shape, Concatenate, zero, {1, 2}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -1387,6 +1435,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { // Test a concatenate with only empty operands is removed. 
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1401,20 +1450,21 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } // Test that concat with a scalar broadcast becomes a pad. TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); @@ -1427,17 +1477,18 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); } // Test that a simplification which changes layouts is not performed if layout // sensitive is true. 
TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1445,7 +1496,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to different layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1455,7 +1506,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -1464,6 +1515,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { // Test that a simplification which preserves layouts is performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1471,7 +1523,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to same layouts. 
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -1481,7 +1533,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Copy has been removed. EXPECT_THAT(computation->root_instruction(), param0); @@ -1490,6 +1542,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { // Test that a reshape which could be replaced with a bitcast is not if // add_bitcasts is false. TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1502,13 +1555,13 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -1516,6 +1569,7 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { // Test transforming reshapes and transposes of rng. 
TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); @@ -1532,11 +1586,11 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { ShapeUtil::MakeShape(F32, {4}), transpose)) ->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that that reshape(transpose(rng)) is replace by a single rng of the // same shape as the reshape. @@ -1547,6 +1601,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { // Test transforming reshapes to bitcasts under various conditions. TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1578,7 +1633,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Tuple(transformable_reshape, dimensions_wrong_reshape, @@ -1586,7 +1641,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - simplifier.Run(&module()).ValueOrDie(); + simplifier.Run(m.get()).ValueOrDie(); // Verify that only the first reshape is replaced. 
EXPECT_THAT( @@ -1597,6 +1652,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. @@ -1613,13 +1669,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. 
@@ -1637,11 +1694,12 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1655,19 +1713,20 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1681,19 +1740,20 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1706,19 +1766,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = 
builder.AddInstruction(HloInstruction::CreateParameter( @@ -1733,18 +1794,19 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1757,13 +1819,13 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); EXPECT_EQ(std::vector({2, 1, 0}), @@ -1772,6 +1834,7 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { // Test merging reshape and broadcast. 
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {5}), "param0")); @@ -1780,20 +1843,21 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } // Test merging broadcast and reshape. TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0")); @@ -1802,19 +1866,20 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { 
+ auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1823,20 +1888,21 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1845,14 +1911,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); EXPECT_THAT(computation->root_instruction()->dimensions(), @@ -1860,6 +1926,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) 
{ } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1868,14 +1935,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector broadcast_dims = @@ -1885,6 +1952,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1893,33 +1961,34 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction(HloInstruction::CreateIota( ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); EXPECT_TRUE( @@ -1927,18 +1996,19 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { } TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); auto result_shape = iota->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Iota()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -1948,37 +2018,39 @@ TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder 
builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); EXPECT_EQ(Cast(computation->root_instruction()) @@ -1987,19 +2059,20 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( 
ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Iota()); const int64 iota_dim = @@ -2009,19 +2082,20 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); } @@ -2043,14 +2117,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2076,7 +2150,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2095,7 +2169,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -2110,14 +2184,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2133,14 +2207,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2162,14 +2236,14 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice, /*start_indices=*/{2, 3}, /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); @@ -2178,6 +2252,57 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + const int64 dim2 = 13; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {dim0, dim1, dim2}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2}), original_reshape, + /*start_indices=*/{0, 0, 0}, + /*limit_indices=*/{dim0 - 2, dim1, dim2}, /*strides=*/{1, 1, 1})); + auto module = 
CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Slice(param))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {3600, 512}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {960, 512}), original_reshape, + /*start_indices=*/{0, 0}, + /*limit_indices=*/{960, 512}, /*strides=*/{1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Reshape(param))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); @@ -2185,11 +2310,11 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } @@ -2207,15 +2332,191 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { builder.AddInstruction(HloInstruction::CreateSort( ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, keys, {values0, values1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values0, values1)); } +// Test that A && True is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that True && A is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder 
builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A && False is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that False && A is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + 
HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that A || True is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that True || A is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = 
builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that A || False is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that False || A is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + 
builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Used for TEST_Ps that test merging (or not) of a kPad instruction into a // convolution's Window. struct ConvPaddingTestcase { @@ -2337,15 +2638,15 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(), lhs_pad, filter, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); @@ -2455,15 +2756,15 @@ TEST_P(ConvFilterPaddingTest, DoIt) { input, rhs_pad, /*feature_group_count=*/1, window, dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - 
ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); @@ -2604,7 +2905,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2724,7 +3025,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -2734,10 +3035,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2763,7 +3064,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -2772,7 +3073,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2782,7 +3083,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2864,7 +3165,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. 
@@ -2954,12 +3255,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2970,6 +3271,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { // Dots add computations to the parent module. Test that, when the HloModule's // computations are updated, then iterator invalidation doesn't occur // when running on subsequent computations. + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {1}); HloComputation::Builder builder(TestName() + ".Dot"); HloInstruction* x = @@ -2991,15 +3293,16 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - module().AddEmbeddedComputation(std::move(dot_computation)); - module().AddEntryComputation(call_builder.Build()); + m->AddEmbeddedComputation(std::move(dot_computation)); + m->AddEntryComputation(call_builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. 
TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -3008,11 +3311,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } @@ -3021,6 +3324,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { // of its input equals the size of its output. In this case, the dynamic slice // is equal to its input. TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -3032,10 +3336,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Parameter()); } @@ -3043,6 +3347,7 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { // size of its "update" equals the size of its output. 
In this case, the // dynamic-update-slice is equal to its update. TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -3065,16 +3370,17 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::DynamicSlice(op::Parameter(), op::Parameter())); } // Test that two consecutive broadcasts can be merged to one. TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( @@ -3085,12 +3391,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { builder.AddInstruction( HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); EXPECT_THAT(root->dimensions(), ElementsAre(2)); @@ -3098,6 +3404,7 @@ TEST_F(AlgebraicSimplifierTest, 
MergeBroadcasts) { // Test that two consecutive broadcasts can be merged to one. TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); @@ -3111,12 +3418,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); @@ -3124,6 +3431,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { // Test that a broadcast of an iota can be merged to one iota. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* iota = @@ -3131,12 +3439,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); EXPECT_EQ(Cast(root)->iota_dimension(), 2); @@ -3144,6 +3452,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { // Test that a broadcast of an iota can be merged to one iota. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); HloInstruction* iota = @@ -3152,12 +3461,12 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Iota()); EXPECT_EQ(Cast(root)->iota_dimension(), 2); @@ -3174,9 +3483,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3196,9 +3504,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3218,9 +3525,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - 
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3238,9 +3544,8 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]} } )"; - TF_ASSERT_OK_AND_ASSIGN( - auto module, - HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); @@ -3249,6 +3554,92 @@ TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { EXPECT_THAT(root, op::Parameter()); } +TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[2:3]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(1)); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[4:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Slice(op::Parameter(2))); + EXPECT_EQ(root->slice_starts(0), 1); + EXPECT_EQ(root->slice_limits(0), 2); +} + +TEST_F(AlgebraicSimplifierTest, NegateNegate) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + neg.0 = f32[2] negate(param.0) + ROOT neg.1 = f32[2] negate(neg.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(0)); +} + +TEST_F(AlgebraicSimplifierTest, NotNot) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = pred[2] parameter(0) + not.0 = pred[2] not(param.0) + ROOT not.1 = pred[2] not(not.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter(0)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -3278,6 +3669,7 @@ class PadReduceWindowEffectiveBroadcastTest PadReduceWindowEffectiveBroadcastCase> {}; TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { + auto m = CreateNewVerifiedModule(); const auto& param = GetParam(); // a and b are parallel bounds we can either turn into a B F S0 S1 or @@ -3326,7 +3718,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); 
builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Window window = window_util::MakeWindow( @@ -3340,10 +3732,10 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { builder.AddInstruction(HloInstruction::CreateReduceWindow( output_shape, pad, zero, window, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -3392,6 +3784,7 @@ class DotStrengthReductionTest public ::testing::WithParamInterface< ::testing::tuple> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + auto module = CreateNewVerifiedModule(); int m, k, n; bool transpose_lhs, transpose_rhs; PrimitiveType element_type; @@ -3421,10 +3814,10 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_dnums.add_rhs_contracting_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = dot_should_be_transformed || (transpose_lhs && transpose_rhs); @@ -3452,7 +3845,7 @@ struct DotOfConcatTestSpec { }; class 
DotOfConcatSimplificationTest - : public HloVerifiedTestBase, + : public HloTestBase, public ::testing::WithParamInterface {}; // Test that we transform @@ -3460,6 +3853,7 @@ class DotOfConcatSimplificationTest // to // add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3498,10 +3892,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( @@ -3519,6 +3913,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { // to // add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3562,10 +3957,10 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( 
ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3590,6 +3985,7 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { // Test that DynamicUpdateSlice update param with any dimension equal to zero // gets removed. TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); HloInstruction* const operand = builder.AddInstruction( @@ -3602,11 +3998,11 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), operand); } @@ -3625,7 +4021,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloVerifiedTestBase, + : public HloTestBase, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) @@ -3634,6 +4030,7 @@ class DotOfGatherSimplificationTest // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. 
TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3680,10 +4077,10 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3704,6 +4101,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. 
TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3750,10 +4148,10 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a645f98220ec445bb9bbdf2b9b842109..52ec1a794c5e9f4452a4bf2b648f453d8acfe976 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -17,14 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { @@ -38,11 +37,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -61,11 +61,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -84,11 +85,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, 
op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -107,11 +109,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -130,11 +133,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -153,11 +157,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index f7ac8f5482908af104554a1cf812370b9098cda7..08cf8026177d77ff98cca5e5d168ac3194936b35 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ 
b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloVerifiedTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -59,14 +59,14 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { param0, param1, param2, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -101,14 +101,14 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { param1, param2, param3, param4, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie()); - for (auto* instruction : module().entry_computation()->instructions()) { + for (auto* instruction : m->entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 5f93740887aa7e61458990992fe0573883ff056d..4ce351acc2c359773e618da70360c96faf5ca379 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { +class BFloat16ConversionFoldingTest : public HloTestBase { protected: BFloat16ConversionFoldingTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -103,10 +103,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -138,10 +138,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -173,10 +173,10 @@ TEST_F(BFloat16ConversionFoldingTest, 
DoNotFoldUnsupportedMixedPrecision) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -203,10 +203,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { HloInstruction* convert1 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -216,7 +216,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("add"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -252,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index cb075a5e38a5ea9db2ceb432b2b59f8db5e2e640..9f97d18c565c7915b9f9346f0c6330cdc3c707e9 100644 
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloVerifiedTestBase { +class BFloat16NormalizationTest : public HloTestBase { protected: BFloat16NormalizationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -106,10 +106,10 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module)); + EXPECT_FALSE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -134,10 +134,10 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { HloInstruction* mul1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); 
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -164,10 +164,10 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { HloInstruction* sub1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -191,7 +191,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, reduce_comp_param0, reduce_comp_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(reduce_comp_builder.Build()); @@ -205,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -233,7 +233,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -263,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = 
module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -272,7 +272,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -290,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -299,7 +299,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -314,7 +314,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); @@ -342,10 +342,10 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); 
auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 0af71eaac96fca366e45430788e769c618f86bb5..5be7141aae423adb4fe2f39262e463ff25ae8234 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloVerifiedTestBase { +class BFloat16PropagationTest : public HloTestBase { protected: BFloat16PropagationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
@@ -121,10 +121,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,6 +136,62 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { EXPECT_FALSE(OutputsBF16(c)); } +TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) { + auto module = CreateNewVerifiedModule(); + + auto sub_builder = HloComputation::Builder("max"); + HloInstruction* p0 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "a")); + HloInstruction* p1 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "b")); + sub_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, p0, p1)); + auto max_computation = module->AddEmbeddedComputation(sub_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + 
*window.add_dimensions() = dim; + HloInstruction* rw = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + shape, add, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), + window, max_computation)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), c, {1, 0})); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, rw)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(rw)); +} + // Tests that side-effecting all-reduce should not be changed. TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { auto module = CreateNewVerifiedModule(); @@ -186,10 +242,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -242,10 +298,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + 
EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -281,10 +337,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction* dot = builder.AddInstruction( CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -310,10 +366,10 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -321,7 +377,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { // Tests that BF16 is propagated properly through fused computations. 
TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -356,7 +412,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -369,7 +425,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { // Tests that changes to BF16 that cannot be propagated outside a fusion are // discarded. TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -393,7 +449,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -408,7 +464,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { // (BF16, BF16) fusion_computation(F32 a, F32 b) // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -439,7 +495,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); 
EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -458,7 +514,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { // on_true and on_false must match, so that as long as one of them is F32, the // other must be F32 as well. TEST_F(BFloat16PropagationTest, SelectOverTuples) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -489,7 +545,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -502,7 +558,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { // Tests that BF16 is propagated properly through a while computation with // non-tuple input/output. TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -545,7 +601,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -561,7 +617,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { // made to the while body and thus the fusion node inside it. 
TEST_F(BFloat16PropagationTest, ConditionPreventsPropagationForFusionInsideWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -610,7 +666,7 @@ TEST_F(BFloat16PropagationTest, auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -622,7 +678,7 @@ TEST_F(BFloat16PropagationTest, // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -690,7 +746,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -709,7 +765,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // Tests that BF16 is not propagated through multiple whiles that invoke the // same computation as long as one while prevents the propagation. 
TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -820,7 +876,7 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -859,10 +915,10 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( bf16_shape, HloOpcode::kAdd, convert0, convert1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -895,10 +951,10 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -941,10 +997,10 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 5b48f10505e78c035608d4c575501e4623218987..2b9502f63a821f3675ddfb506f41bb2390cf4136 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -107,6 +108,21 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kSelect: case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + HloComputation* reduce_comp = hlo.called_computations()[0]; + for (HloInstruction* inst : reduce_comp->instructions()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) { + return false; + } + } + } + return true; + } default: break; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index ee4e5942731110e16c8396a824e6dbd19c9df607..40c012a5e4214f00dbeaca4e8cbfaa668089c6e8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -641,7 +641,7 @@ Status BufferAssignment::ComputeSummaryStats() { bool schedule_complete = true; 
for (const auto& computation : module_->computations()) { if (!computation->IsFusionComputation()) { - const std::vector* sequence = + const HloInstructionSequence* sequence = liveness_->hlo_ordering().SequentialOrder(*computation); if (sequence == nullptr) { schedule_complete = false; @@ -1180,7 +1180,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloComputation* computation = pair.first; const flat_hash_set& buffers_to_assign = pair.second; - const std::vector* instruction_sequence = + const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); schedule.set_sequence(computation, *instruction_sequence); @@ -1215,7 +1215,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloComputation* computation = pair.first; const flat_hash_set& buffers_to_assign = pair.second; - const std::vector* instruction_sequence = + const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); auto color_map = SplitBuffersByColor(buffers_to_assign); @@ -1230,7 +1230,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), *computation, - HloInstructionSequence(*instruction_sequence), + *instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 327211d3efd24177a28cc2d08dc3c4fbf2fbaff9..b1fc50cb1881241a0a53b024b06342308cabdd62 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -38,7 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -81,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloVerifiedTestBase { +class BufferAssignmentTest : public HloTestBase { protected: ~BufferAssignmentTest() override {} @@ -334,16 +334,16 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -358,17 +358,17 @@ TEST_F(BufferAssignmentTest, BufferForConst) { LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto 
buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -387,10 +387,10 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate, param0, constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -410,10 +410,10 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // The copy node now has an output buffer. GetAssignedOutputAllocation(*buffers, copy); } @@ -439,10 +439,10 @@ TEST_F(BufferAssignmentTest, Basic) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -538,7 +538,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -553,7 +553,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -599,7 +599,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -622,7 +622,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -671,10 +671,10 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -706,7 +706,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // param0[100x10] ---> (map x+1) // // Builds the map function. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto map_computation = module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); auto inner_last = map_computation->root_instruction(); @@ -725,7 +725,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -761,7 +761,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // out-of-order reductions could overwrite an element before a use.) 
// // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); @@ -784,7 +784,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -812,7 +812,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // const4[f32[4]] --- tuple --- while[condition, body] // // Builds the nested condition and body. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto condition_computation = module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); auto body_computation = @@ -840,7 +840,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. 
- auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -878,7 +878,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { } TEST_F(BufferAssignmentTest, ExampleConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto true_computation = module->AddEmbeddedComputation( BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); auto false_computation = module->AddEmbeddedComputation( @@ -905,7 +905,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -941,9 +941,9 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto neg = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -970,9 +970,9 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. 
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1003,9 +1003,9 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1040,9 +1040,9 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1075,9 +1075,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. 
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1107,9 +1107,9 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1145,9 +1145,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); builder.AddInstruction(HloInstruction::CreateTuple({broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1160,7 +1160,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Verify that buffers for embedded computations are properly marked as // thread-local and that embedded parameters are not marked as // is_entry_computation_parameter. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto vec_shape = ShapeUtil::MakeShape(F32, {42}); auto scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1191,7 +1191,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Allocations for the map computation should be thread-local and not // live-out. 
@@ -1238,9 +1238,9 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { ShapeUtil::MakeShape(S32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1274,9 +1274,9 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Only some of the elements of the input param are liveout. EXPECT_FALSE( @@ -1318,9 +1318,9 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1332,9 +1332,9 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), ShapeUtil::MakeShape(S32, {101})}), /*operands=*/{}, /*custom_call_target=*/"foo_function")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1347,7 +1347,7 @@ 
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { TEST_F(BufferAssignmentTest, TupleCallAsOutput) { // Test a computation which returns a tuple call value. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1365,7 +1365,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1388,7 +1388,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { // B: call(C, param) // C: call(D, param) // D: param - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1427,7 +1427,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1461,9 +1461,9 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto bitcast = builder.AddInstruction( HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Bitcast should get the same allocation as the param. 
EXPECT_EQ(1, assignment->Allocations().size()); @@ -1488,9 +1488,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, pred_param, tuple_param0, tuple_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. @@ -1526,9 +1526,9 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto copy = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape, HloOpcode::kCopy, tuple_element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. @@ -1568,9 +1568,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module, /*alignment=*/1); + auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1589,7 +1589,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. 
- assignment = RunBufferAssignment(module, /*alignment=*/64); + assignment = RunBufferAssignment(module.get(), /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1632,10 +1632,10 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = @@ -1673,11 +1673,11 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module, {param, log, rev, neg, concat, root}); + module.get(), {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. 
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1698,7 +1698,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { } TEST_F(BufferAssignmentTest, PeakBuffersWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); HloComputation* condition; { @@ -1733,7 +1733,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1783,13 +1783,13 @@ ENTRY main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* constant_1 = - module().entry_computation()->GetInstructionWithName("constant.1.1"); + m->entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module().entry_computation()->GetInstructionWithName("constant.1.2"); + m->entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); { const BufferAllocation& allocation_for_const_1 = @@ -1818,7 +1818,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloVerifiedTestBase { +class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( const string& name) { @@ -1878,7 +1878,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1917,8 +1917,8 @@ TEST_F(WhileBufferAssignmentTest, 
TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Verify 'input0' and read-only use while0{0} alias. EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1974,20 +1974,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* param = - module().entry_computation()->parameter_instruction(0); + m->entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1995,7 +1994,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. 
- auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2042,20 +2041,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* constant = - module().entry_computation()->GetInstructionWithName("constant.42"); + m->entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -2063,7 +2061,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2121,7 +2119,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto token = builder.AddInstruction(HloInstruction::CreateToken()); @@ -2156,7 +2154,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module).status()); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry @@ -2175,12 +2173,12 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run(module, - absl::make_unique(schedule), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run( + module.get(), absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. 
TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2202,7 +2200,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2234,8 +2232,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // while0 and while1 buffers should be completely aligned. EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2247,7 +2245,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -2277,13 +2275,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); } - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2308,13 +2306,14 @@ ENTRY Main { )"; HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_text, config); + 
config.set_debug_options(GetDebugOptionsFromFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(hlo_text, config)); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); - HloComputation* main = module().entry_computation(); - HloComputation* callee = module().GetComputationWithName("Callee"); + HloComputation* main = m->entry_computation(); + HloComputation* callee = m->GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2338,7 +2337,7 @@ ENTRY Main { } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2385,11 +2384,11 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } - RunCopyInsertion(module); + RunCopyInsertion(module.get()); HloSchedule schedule = ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); @@ -2407,18 +2406,18 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run(module, - absl::make_unique(schedule), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run( + module.get(), absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = CreateNewModule(); + auto 
module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2462,8 +2461,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Get BufferAllocation for root instruction. auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 17e50905059ad2c92784d14132c1cb1f46f35ade..aeee543e8435200915ab992e2aa146a3c17646d5 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -117,7 +117,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto log = builder.AddInstruction( HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -164,7 +164,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -213,7 +213,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -247,7 +247,7 @@ TEST_F(BufferLivenessTest, 
OverlappedBuffers) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -289,7 +289,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -336,7 +336,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build(add)); HloSchedule schedule(module.get()); @@ -373,7 +373,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto outer_tuple = builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -393,7 +393,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { TEST_F(BufferLivenessTest, EmbeddedComputation) { // Test MaybeLiveOut and MayInterfere for embedded computation. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); auto embedded_param = embedded_builder.AddInstruction( @@ -450,7 +450,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0.shape(), tuple_constant, 0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -514,7 +514,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -576,7 +576,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -646,7 +646,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. @@ -802,7 +802,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 34f3f914d593bc603c4964663f9cafb70a136fd3..a3ac2568b0f3eec8556a42dbe3c2c64bd8564468 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloVerifiedTestBase { +class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( @@ -93,10 +93,10 @@ class CallGraphTest : public HloVerifiedTestBase { TEST_F(CallGraphTest, SingletonComputation) { // Test the call graph of a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -112,13 +112,13 @@ TEST_F(CallGraphTest, SingletonComputation) { TEST_F(CallGraphTest, UnreachableComputation) { // Test the call graph of a module with an entry computation and an // unreachable computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -134,13 +134,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { TEST_F(CallGraphTest, ParallelComputation) { // Test a call graph of a module with an entry computation which calls another // computation in a parallel context via kMap. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* map_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -163,13 +163,13 @@ TEST_F(CallGraphTest, ParallelComputation) { TEST_F(CallGraphTest, SequentialComputations) { // Test a call graph of a module with an entry computation which calls another // computation in a sequential context via kCall. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* called_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -196,7 +196,7 @@ TEST_F(CallGraphTest, SequentialComputations) { TEST_F(CallGraphTest, ContextBothComputations) { // Test a call graph of a module with an entry computation which calls another // computation in both a parallel and sequential context. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -239,7 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { TEST_F(CallGraphTest, ComputationWithConditional) { // Test a call graph of a module with a conditional. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* true_computation = module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); HloComputation* false_computation = @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(3, call_graph->nodes().size()); @@ -298,7 +298,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -418,7 +418,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -479,10 +479,10 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { TEST_F(CallGraphTest, VisitSingletonComputation) { // Test the call graph visitor with a call graph with a single node. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -494,12 +494,12 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { TEST_F(CallGraphTest, VisitUnreachableComputation) { // Test the call graph visitor with a call graph with an unreachable node. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // Test visitation of only reachable nodes. { @@ -531,9 +531,9 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index e6b566543594a86eb5369ee9b7440f62618f6c5a..0b6e323f75c7a5dae127e20d2a4b92a83a72df3b 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloVerifiedTestBase; +using CallInlinerTest = HloTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -51,7 +51,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { HloInstruction* one = inner.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* inner_computation = module->AddEmbeddedComputation(inner.Build()); @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -79,7 +79,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // returns false should be identical to just returning false). 
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Create a lambda that calls a function that returns the false predicate. // Note we also use this lambda twice by reference, just to make the test a @@ -107,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -120,7 +120,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // whole pass. TEST_F(CallInlinerTest, InlineWithoutRunningPass) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { const Shape f32 = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( @@ -163,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc new file mode 100644 index 
0000000000000000000000000000000000000000..2662fe46705f4936ce0d654df0943e7d30890ebe --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compilation_cache.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = counter++; + return id; +} + +} // namespace + +ExecutionHandle CompilationCache::Insert( + std::unique_ptr executable) { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = GetUniqueId(); + VLOG(2) << "inserting cache key: " << key; + CHECK_EQ(cache_.count(key), 0); + cache_.emplace(key, std::move(executable)); + + ExecutionHandle handle; + handle.set_handle(key); + return handle; +} + +StatusOr> CompilationCache::LookUp( + const ExecutionHandle& handle) const { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = handle.handle(); + VLOG(2) << "looking up cache key: " << key; + if (cache_.count(key) == 0) { + 
VLOG(2) << "cache key not found: " << key; + return InvalidArgumentStrCat("can not find executable with handle ", key); + } else { + auto& result = cache_.at(key); + VLOG(2) << "hit executable: " << result->module().name(); + return result; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..5f94def509d4d4a8950272cb498af5056a698ce0 --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// A cache which stores Executables indexed by computation handle and version. +// +// TODO(b/119042872): Provide mechanism for removing computations from the +// compilation cache. 
+class CompilationCache { + public: + CompilationCache() {} + + ExecutionHandle Insert(std::unique_ptr executable); + + // Lookup the Executable for the specified handle in the cache. Return a + // shared_ptr to the Executable if it exists in the cache. + StatusOr> LookUp( + const ExecutionHandle& handle) const; + + protected: + mutable tensorflow::mutex mutex_; + + using CacheKey = int64; + + absl::flat_hash_map> cache_ + GUARDED_BY(mutex_); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 6d67f970020d278cc7bf61b56350200d3e5cb926..67132274c0dcbfda831c79836d052bb51b753ec7 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 80c630c6201503d88a690f04a88f6fca6f3a438a..8f08c244908efb823b3870c19bdc3491fa87d44f 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -110,6 +110,6 @@ Compiler::GetPlatformCompilers() { } AotCompilationOptions::AotCompilationOptions() - : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + : debug_options_(GetDebugOptionsFromFlags()) {} } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc 
b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167d47af3c92ed35fa52594fa5da1e4af..289eb6d90239a72ecc0f3312a7e0e8453f946858 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -37,7 +37,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ConditionalSimplifierTest : public HloVerifiedTestBase { +class ConditionalSimplifierTest : public HloTestBase { public: // Makes a computation that contains a conditional with constant predicate. 
HloComputation* MakeConditional(HloModule* module); @@ -96,25 +96,28 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { - HloComputation* computation = MakeConditional(&module()); - ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); + ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Constant())); } TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* true_op = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -125,11 +128,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeConditional(&module()); + auto m = 
CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -138,18 +142,19 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); false_computation->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2..7f7f1503a099b3a67ed22cb5978c01da6cf8ba88 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -51,7 +51,8 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; // Runs the visitor on a computation. 
- static bool Run(HloComputation* computation); + static bool Run(HloComputation* computation, + bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -59,18 +60,24 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation) - : computation_(computation) {} + explicit ConvolutionVisitor(HloComputation* computation, + bool canonicalize_depthwise_filter = false) + : computation_(computation), + filter_expansion_(!canonicalize_depthwise_filter) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; // Whether rewrite has occurred. bool changed_ = false; + + // Whether filter expansion is required. + bool filter_expansion_; }; -bool ConvolutionVisitor::Run(HloComputation* computation) { - ConvolutionVisitor visitor(computation); +bool ConvolutionVisitor::Run(HloComputation* computation, + bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -190,9 +197,49 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { HloInstruction* filter_mask = GetExpandedFilterMask( filter->shape(), input_feature_dim, output_feature_dim, group_count, add); HloInstruction* expanded_filter; - // We want to repeat 'filter' in the 'input_feature_dim' dimension - // 'group_count' times. + if (group_size == 1) { + bool depthwise_separable = + (group_count == filter->shape().dimensions(output_feature_dim)); + // If the code generator handles depthwise separable convolutions + // inherently, then no filter expansion is needed. 
+ if (!filter_expansion_ && depthwise_separable) { + const int64 old_kernel_input_feature_dimension = + dim_numbers.kernel_input_feature_dimension(); + const int64 old_kernel_output_feature_dimension = + dim_numbers.kernel_output_feature_dimension(); + + // For depthwise convolutions, we want the kernel input feature dimension + // to be smaller than the output feature dimension. If that's not the + // case, we swap the dimensions. + if (old_kernel_input_feature_dimension > + old_kernel_output_feature_dimension) { + Shape reshaped_filter_shape = filter->shape(); + auto& dimensions = *reshaped_filter_shape.mutable_dimensions(); + std::swap(dimensions[old_kernel_input_feature_dimension], + dimensions[old_kernel_output_feature_dimension]); + + auto reshaped_filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + dim_numbers.set_kernel_input_feature_dimension( + old_kernel_output_feature_dimension); + + dim_numbers.set_kernel_output_feature_dimension( + old_kernel_input_feature_dimension); + + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), + reshaped_filter, group_count, convolution->window(), dim_numbers, + convolution->precision_config()); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); + } + return Status::OK(); + } + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. 
Shape reshaped_filter_shape = ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); auto reshaped_filter = @@ -237,7 +284,7 @@ StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp)) { + if (ConvolutionVisitor::Run(comp, filter_expansion_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index ce0138e56fbd51daaf5d3ac329ccbe31a9fdbde7..cb6bc04c00a2ff10f970da2a07fb540a561dad5a 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -27,7 +27,8 @@ namespace xla { // convolutions with feature_group_count = 1. class ConvolutionFeatureGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter() {} + ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) + : filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { return "convolution-feature-group-converter"; @@ -36,6 +37,9 @@ class ConvolutionFeatureGroupConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + // Tells whether filter expansion is required. + bool filter_expansion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4533ebb99bbba854a029fb8a9a1e31b023be720d..7446bc7cc11553984dcf1cea00c58072d2cbf0f0 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -94,7 +94,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -114,7 +114,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -127,7 +127,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = @@ -181,7 +181,7 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -217,7 +217,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloInstruction* old_root = module->entry_computation()->root_instruction(); @@ -238,7 +238,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -261,7 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); @@ -283,7 +283,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -310,7 +310,7 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, @@ -351,7 +351,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -388,7 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -403,7 +403,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { class WhileCopyInsertionTest : public CopyInsertionTest { protected: - WhileCopyInsertionTest() : module_(CreateNewModule()) {} + WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {} // Builds a While condition computation which reads the induction variable // from the tuple parameter, and returns a predicate indicating whether this @@ -1295,7 +1295,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { TEST_F(CopyInsertionTest, SwizzlingWhile) { // Test a while instruction with a body which permutes its tuple parameter // elements. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1362,7 +1362,7 @@ TEST_F(CopyInsertionTest, CrossingParameters) { // | / \ | // | / \| // (p1 , p0) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1395,7 +1395,7 @@ TEST_F(CopyInsertionTest, ParametersAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1428,7 +1428,7 @@ TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1461,7 +1461,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { // | | // | | // (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1496,7 +1496,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { // | | | // | | | // +-- (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1534,7 +1534,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { // | Add----+ // | | | // +-- (p0 , p1) - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1569,7 +1569,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // the operation (instruction) on the element makes the live range of the // respective input and 
output elements different than if the instruction were // not there (as in the SwizzlingWhile test above). - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1632,7 +1632,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { // the while body is a single constant (both loop state elements are the same // constant). This means no copies are necessary because both loop state // elements are the same so interchanging them is a no-op. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1693,7 +1693,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { const Shape loop_state_shape = ShapeUtil::MakeTupleShape( {element_shape, element_shape, element_shape, element_shape}); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, element_shape, "param_0")); @@ -1783,7 +1783,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). The body constant should be copied. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -1896,7 +1896,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); @@ -1936,7 +1936,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); @@ -2003,7 +2003,7 @@ std::unique_ptr MakeBenchmarkWhileBody( void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { tensorflow::testing::StopTiming(); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); CopyInsertion copy_insertion; const Shape element_shape = ShapeUtil::MakeShape(F32, {}); std::vector tuple_params(num_tuple_inputs); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 36e25cbe678e03f511934eb00af8c3834de2c63e..2763d18121a0c1328ea0c11d825476923ae2b15d 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -824,7 +824,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -846,7 +845,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -887,7 +885,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -961,17 +958,16 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -997,7 +993,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 
2083f440fdd971db1b675d005664d25e6de53dbe..c58175428fea6a2d38253c35de598b99a4281bf1 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloVerifiedTestBase { +class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -87,7 +87,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -150,7 +150,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); 
module->AddEntryComputation(builder.Build()); cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( @@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index c9fb34be1cd582c71618c770c892058c233c571a..c085f85fb73e98e4c7ba15af8db8bb19c2499f5f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloVerifiedTestBase { +class CpuCopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -65,7 +65,7 @@ class CpuCopyInsertionTest : public HloVerifiedTestBase { TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). Each constant should be copied. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 3); @@ -103,7 +103,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { // Test a kCall instruction which calls a computation which produces a three // element tuple: one is a constant, one is a parameter, and one is produced // in the computation. The constant and parameter should be copied. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index e6b6fcdf684eadb3702e490bbe24dbb7b3b52ec7..9cbfb88834bf51f4df54e97efe6cd7bf88b12334 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloVerifiedTestBase { +class CpuHloSupportCheckerTest : public HloTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(CpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + 
TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 7d99b914d4f5e5d27722bcd098d2ae0c54a36a23..c95a514ca04bee1fb4c03ee21510eb8da3122081 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -58,7 +58,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -98,7 +98,7 @@ TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloInstruction* dot = builder.AddInstruction( 
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -119,7 +119,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -138,7 +138,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -157,7 +157,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -321,7 +321,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); 
RunFusionAndCheckOpcodesWereFused( @@ -350,7 +350,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -370,7 +370,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, broadcast1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -392,7 +392,7 @@ TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -410,7 +410,7 @@ TEST_F(OpcodeFusionTest, Exponential_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -429,7 +429,7 @@ TEST_F(OpcodeFusionTest, Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -447,7 +447,7 @@ TEST_F(OpcodeFusionTest, Reverse_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); 
RunFusionAndCheckOpcodesWereFused( @@ -466,7 +466,7 @@ TEST_F(OpcodeFusionTest, Slice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -489,7 +489,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, transpose2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -498,7 +498,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { } TEST_F(OpcodeFusionTest, UnaryMapOfExp) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -517,7 +517,7 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { } TEST_F(OpcodeFusionTest, BinaryMapOfExps) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -542,7 +542,7 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { } TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -573,7 +573,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); @@ -641,7 +641,7 @@ TEST_F(OpcodeFusionTest, 
ReuseViaImplicitBroadcastUnary) { builder.AddInstruction( HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -670,7 +670,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { builder.AddInstruction(HloInstruction::CreateBinary( large_shape, HloOpcode::kAdd, small_exp, large_param)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -712,7 +712,7 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, } TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -725,7 +725,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/false); @@ -738,7 +738,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -751,7 +751,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { - auto module = CreateNewModule(); + auto module = 
CreateNewUnverifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/true); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 97659b88a7974d7caf91ab0d4741f3635e4dae4a..2cd52e4a18a4524365393db5f658a982d83a7632 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -73,7 +73,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -114,7 +114,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -158,7 +158,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -192,7 +192,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto 
module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -232,7 +232,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -353,7 +353,7 @@ static void AssertCorrectLayoutForDotOutputFusion( } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -365,7 +365,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -377,7 +377,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -389,7 +389,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { } 
TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -401,7 +401,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, @@ -413,7 +413,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index d6968323f337d83e41b5e031cc49fab5b6a17b21..620c45fa391e69ef88269d44709404e6f71b30cb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1536,7 +1536,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs, rhs}, {lhs->getType()}, b); @@ -1551,7 +1552,8 @@ IrEmitter::ReductionGenerator 
IrEmitter::MatchReductionGenerator( case HloOpcode::kMinimum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs, rhs}, {lhs->getType()}, b); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fad76338a57cd9eb21d9469ca8552efa8ea0129b..f0b65046c14ccec5336abf7c4d05d1d755f783bd 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class ParallelTaskAssignmentTest : public HloVerifiedTestBase { +class ParallelTaskAssignmentTest : public HloTestBase { protected: const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = cpu::CpuExecutable::ShapeSizeBytes; @@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { + : HloTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} @@ -57,8 +57,9 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -84,8 +85,9 @@ TEST_F(ParallelTaskAssignmentTest, } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -100,8 +102,9 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -116,8 +119,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 1a3d82de954318368d61e3feeb0345dc592dcd8b..7d8e51f909e3db699b745f94a6c625407bc4a6e3 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloVerifiedTestBase { +class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloVerifiedTestBase { +class ShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { +class RandomShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4b129c95d46d8b5a119e5d23eef387daf7863cce..382dfd0d99df87bbadfe541ddaa32cd6da8e8068 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,7 +48,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 18ee25ba9158c28baaf01492c290638b9673f1ec..691b3c7bee26e84edbef18a4ac10a9cafd29c61a 100644 --- 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -50,7 +50,7 @@ class CpuEigenDotOperationTest /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(entry_computation)); CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 00a7aa2ad2f6bac4877302296ccb76222557535c..d201a151d7a9edb86a0de15819ea99f95a9c4d28 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -46,7 +46,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); CompileAndVerifyIr(std::move(module), filecheck_pattern, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 1deb412064b02988a8d4a6d726969c948d354d47..04a81dfd35f459ff1fdb3181dc8fc65c62a37d4f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloVerifiedTestBase { +class CpuFusionTest : public HloTestBase { protected: CpuFusionTest() {} @@ -57,11 +57,11 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -104,11 +104,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -131,7 +131,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal.shape(); @@ -183,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -250,12 +250,12 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); // Create computation and module. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -310,11 +310,11 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate1, negate2, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). 
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index a434c04a980b9b3cd849792b97a0d9e965ba09f2..773336c7a92f808f0c6370c7353e780b1471470f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -91,7 +91,7 @@ TEST_P(CpuUnaryIntrinsicTest, DoIt) { /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); string check_lines{spec.check_lines.data(), spec.check_lines.size()}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index b35fd9dad877c319c3d0110c96a00aeefa78769e..f5419b7063bea6d1f5d24fde0a22e829413b8d93 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -56,7 +56,7 @@ TEST_F(CpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef..64fb50318394918b277fd717994f5366d762ac36 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -18,19 +18,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class DefuserTest : public HloVerifiedTestBase { +class DefuserTest : public HloTestBase { protected: // Returns the number of fusion instructions in the module. - int FusionCount() { + int FusionCount(const HloModule* m) { int count = 0; - for (HloComputation* computation : module().computations()) { + for (HloComputation* computation : m->computations()) { if (computation->IsFusionComputation()) { count++; } @@ -43,6 +43,7 @@ class DefuserTest : public HloVerifiedTestBase { }; TEST_F(DefuserTest, NoFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -51,13 +52,14 @@ TEST_F(DefuserTest, NoFusionInstruction) { builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - module().AddEntryComputation(builder.Build()); - EXPECT_EQ(0, FusionCount()); + m->AddEntryComputation(builder.Build()); + EXPECT_EQ(0, FusionCount(m.get())); - EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(defuser_.Run(m.get()).ValueOrDie()); } TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -66,21 +68,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = 
m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -91,21 +94,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add(op::Parameter(), op::Parameter()))); } TEST_F(DefuserTest, NonTrivialFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -128,22 +132,23 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = 
module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction( {add2, constant, div, mul, sub, negate, add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, MultipleFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -166,7 +171,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add2, constant, div, mul}, HloInstruction::FusionKind::kLoop); computation->CreateFusionInstruction({sub, negate, add}, @@ -174,15 +179,16 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, NestedFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); 
@@ -193,7 +199,7 @@ TEST_F(DefuserTest, NestedFusionInstructions) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); auto outer_fusion = computation->CreateFusionInstruction( {negate, add}, HloInstruction::FusionKind::kLoop); HloInstruction* fused_negate = outer_fusion->fused_expression_root(); @@ -203,9 +209,9 @@ TEST_F(DefuserTest, NestedFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 4159aa281fa2b66d310d7c135f123a5a3bb83270..d6371283221b63b30f968929fe2807eae3f22df0 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -108,6 +108,7 @@ class DfsHloVisitorBase { virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 4cd10ab06cd3b804406607212d3f3c316d6cff95..e57184f639f4f2c618b980a5082381f4b9c28b19 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ 
b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -203,6 +203,9 @@ class DfsHloVisitorWithDefaultBase Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } + Status HandleGetDimensionSize(HloInstructionPtr get_size) override { + return DefaultAction(get_size); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 515267edd7caf42e04ebe638b99006db8967ea30..f98c943669be8c14d245896b91cee3eee1e47429 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1815,8 +1815,6 @@ StatusOr ElementalIrEmitter::EmitElementalGather( // Clamp the gather index so that the gather region fits in the operand. // gather_dim_component_extended_inbound = // clamp(gather_dim_component_extended, 0, largest_valid_start_index); - - // TODO(b/111078873): This is implementation defined behavior. bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); auto gather_dim_component_extended_inbound = EmitIntegralMin( index.GetConstantWithIndexType(largest_valid_start_index), diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 47c56e2f7fbd9f53be6a2b189c5c36cf4fdcdccb..10b8c01ff1383658fcfb2271c177ba54347f985a 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 3a6780f2a67f230cae626ea00cfbf93b4e60d968..45f620f3f33eee41eefa9ddfdfb166a5ba76caef 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "absl/types/variant.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8..8eeb930b48165a2e3c622581e05cb5f7063fa1fa 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. 
namespace xla { namespace { -class FlattenCallGraphTest : public HloVerifiedTestBase { +class FlattenCallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -108,7 +108,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module); + std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -149,7 +149,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // Test corner case of a computation used as a body and a loop condition. 
TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation; { HloComputation::Builder builder(TestName() + ".cond"); @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -201,7 +201,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { // C // TEST_F(FlattenCallGraphTest, FlattenCalls) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* c_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -224,7 +224,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { } TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { - auto module = CreateNewModule(); + auto 
module = CreateNewVerifiedModule(); HloComputation* sub_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 1e8435fe542f2b65a11e256453cf911c5e6e833b..b1629616acd2bb715d5aa1a89286a38a45417d2c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -111,7 +111,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -463,7 +462,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -627,7 +626,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -849,7 +848,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -909,7 +907,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -1036,6 +1033,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc index fa3afa6a5d318c399dc38e8934199b5a1393669e..af9303a5b761b99705945f1c02303156e3f874de 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvPadForTensorCoresTest : public HloVerifiedTestBase {}; +class CudnnConvPadForTensorCoresTest : public HloTestBase {}; TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc index c46672c598b27670c56b3efa4775be8fea1fc6ac..4ce877f62a55c960765314670288ee626c5fc15b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -254,7 +254,7 @@ MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // TODO(b/119479517): Theoretically cuDNN supports grouped convolutions also // for the backward input convolution, but at least for now with version 7.1.4 // it is slower. This needs to be re-evaluated for future cuDNN versions. 
// Note that we already have the necessary code down below, the only thing to diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index 87a835f2504068548159ef32b276201c936fa385..443883a89f66a747def1049bc5afb53fec3c2409 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -34,11 +34,11 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -class CudnnConvRewriterTest : public HloVerifiedTestBase { +class CudnnConvRewriterTest : public HloTestBase { public: CudnnConvRewriterTest() - : HloVerifiedTestBase(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false) { + : HloTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -118,10 +118,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { metadata.set_op_name("foo"); conv->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -152,10 +152,10 @@ TEST_F(CudnnConvRewriterTest, activations, gradients, /*feature_group_count=*/1, conv_window, 
tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -182,10 +182,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -212,10 +212,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -241,10 +241,10 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); 
EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -292,10 +292,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { /*feature_group_count=*/1, conv_window, conv_dnums) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -338,10 +338,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -371,10 +371,10 @@ TEST_F(CudnnConvRewriterTest, default_conv_window_, tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -425,10 +425,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - 
EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -475,10 +475,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -529,10 +529,10 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -584,10 +584,10 @@ TEST_F(CudnnConvRewriterTest, conv_window, tf_default_dnums_for_backward_input_) .ValueOrDie())); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); + EXPECT_TRUE(RunPass(module.get())); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -600,7 +600,8 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { constant_arr.FillIota(0); string constant_str = LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); - ParseAndVerifyModule(absl::StrFormat(R"( + + const string module_str = 
absl::StrFormat(R"( HloModule test ENTRY entry_computation { @@ -610,10 +611,12 @@ TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) { window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, dim_labels=bf01_01oi->bf01, feature_group_count=1 })", - constant_str)); - EXPECT_TRUE(RunPass(&module())); + constant_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); + + EXPECT_TRUE(RunPass(m.get())); EXPECT_THAT( - module().entry_computation()->root_instruction(), + m->entry_computation()->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, op::Reverse(op::Constant())), 0)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 02a0d028c118aba23996f9b97d05443bb4a00c88..91609c730b6c0d666eb607fb42b918c0f8f250e5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -42,7 +42,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { // Only the entry computation can possibly be sequentially ordered, and only // if we've assigned all instructions to a single stream. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return &computation == module_->entry_computation() ? entry_sequence_.get() : nullptr; @@ -51,7 +51,7 @@ class GpuHloOrdering : public PredecessorHloOrdering { string ToString() const override { return ToStringHelper("GpuHloOrdering"); } private: - std::unique_ptr> entry_sequence_; + std::unique_ptr entry_sequence_; }; GpuHloOrdering::GpuHloOrdering( @@ -60,8 +60,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. 
if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = absl::make_unique>( - thunk_launch_order); + entry_sequence_ = + absl::make_unique(thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering( for (auto* computation : module->computations()) { if (computation != module->entry_computation() && !computation->IsFusionComputation()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, + HloReachabilityMap::Build(computation)); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index b857fa775a76ec999b505a2a64332cc0c54cf00b..6d3aed15ebe7d925eda00a72177a03a2264a640c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,14 +24,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloVerifiedTestBase { +class GpuHloScheduleTest : public HloTestBase { protected: using HloVec = std::vector; @@ -44,7 +44,7 @@ class GpuHloScheduleTest : public HloVerifiedTestBase { .ConsumeValueOrDie(); } - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -79,7 +79,7 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr streams = AssignStreams(*module); @@ -139,7 +139,7 @@ TEST_F(GpuHloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(add3)); std::unique_ptr streams = AssignStreams(*module); @@ -209,7 +209,7 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr streams = AssignStreams(*module); @@ -288,7 +288,7 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { HloInstruction* d40 = 
builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr streams = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 7d01eeb02567d710e9de089c7f29ffcc5f959f9a..b511155f85fb24adc1828cbef7f3fb60778ef7ab 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloVerifiedTestBase { +class GpuHloSupportCheckerTest : public HloTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(GpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { // Since verifier is reporting sparse layouts as errors, we should // use a regular HloModule 
instead of VerifiedHloModule to avoid // verifier errors being triggered in the destructor. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); Status status = checker().Run(module.get()).status(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4822b820f4e229336e2b26cfbd0097c8c31a50c8..8cc76c872c61634ca4344d8a8cdf8c6a75aea2ac 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -61,7 +61,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(add)); @@ -148,7 +148,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { {operand, scale, offset, mean, variance, epsilon, feature_index}, kCudnnBatchNormForwardInferenceCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -217,7 +217,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, kCudnnBatchNormForwardTrainingCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -298,7 +298,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { feature_index}, kCudnnBatchNormBackwardCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = 
module->AddEntryComputation(builder.Build(batchnorm)); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 7f2b59810f0334b24a50fc83b85ab838002afd23..43f43b50e4a6478f343088194871cc9d380bd2d2 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 57e66f5a12cf54824c3139ce2fb32e7cf762b040..fb77bc4b8eb497d09014da96769b52aa606510af 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -41,7 +41,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), exp1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -61,7 +61,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), negate1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -80,7 +80,7 @@ TEST_F(InstructionFusionTest, 
HloInstruction* reshape2 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -99,7 +99,7 @@ TEST_F(InstructionFusionTest, HloInstruction* transpose2 = builder.AddInstruction( HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -134,7 +134,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -723,7 +723,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { sum = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); } - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto 
computation = module->AddEntryComputation(b.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -805,5 +805,26 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) { op::Reduce(op::Broadcast(op::Parameter()), op::Constant())); } +TEST_F(InstructionFusionTest, FuseReverse) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY Reverse { + p0 = f32[50,96,1024]{2,1,0} parameter(0) + add = f32[50,96,1024]{2,1,0} add(p0, p0) + ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reverse(op::Add(op::Parameter(), op::Parameter()))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 21e44e1e7d3fb7818e114b70025bfb85eacf786a..87b6cd640acc41074c40e1d397b9334b76029fd5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2197,9 +2197,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { } int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - auto index_type = b_.getInt64Ty(); + CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); + CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); // Naive C++ code for the outer loops: // @@ -2213,42 +2214,119 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // } // } // - // This follows the algorithm described on Wikipedia: - 
// https://en.wikipedia.org/wiki/Bitonic_sorter - + // This follows the alternative representation of the algorithm described on + // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter + // + // Each mask specifies how to derive from one position in the array the + // position with which it should be compared (we calculate the xor of the + // position with the mask). + // As an optimization, we can move the 'mask' loop to inside the + // sorting/comparison loop if the comparisons happen within a small block of + // the array. To make this work, we collect all consecutive masks that are + // smaller than our chosen power of 2 tile size, and pass them to SortInPlace. + // Each thread then processes one tile of data. + + const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages); + + // If we cannot combine several xor masks together, we don't use tiling, so we + // calculate the standard launch dimensions for the shape. However we only + // need to iterate through ~half of the dimension to sort (rounded up to the + // next highest power of 2), because each iteration compares one pair of + // elements. + Shape standard_iteration_shape = keys_shape; + uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1); + standard_iteration_shape.set_dimensions(dimension_to_sort, + standard_num_iterations_in_sort_dim); + LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions( + standard_iteration_shape, ir_emitter_context_->device_description()); + + // Calculate the launch dimensions for the case where we use tiling. We split + // the dimension that should be sorted into tiles of size 'kTileSize'. This + // means we first need to round 'dimension_to_sort_bound' up to be a multiple + // of the tile size. + int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize); + Shape iteration_shape = keys_shape; + + // We iterate through the element pairs that should be compared. 
+ uint64 num_iterations_in_sort_dim = rounded_bound / 2; + iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim); + uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape); + + // For correctness reasons we need exactly 'kTileSize' / 2 many threads per + // block. Each thread is responsible for copying exactly two adjacent elements + // into shared memory, and then does a comparison of two possibly different + // elements taken from shared memory. + const uint64 kThreadsPerBlock = kTileSize / 2; + + // Check whether we should use any tiling. We might not be able to use it if + // we have not enough threads, or not enough shared memory. Also it does not + // give a speedup if the tile size is < 128. + int64 total_shared_memory_needed = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + total_shared_memory_needed += + kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type()); + } + bool no_tiling = + kTileSize < 128 || + kThreadsPerBlock > + ir_emitter_context_->device_description().threads_per_block_limit() || + total_shared_memory_needed > + ir_emitter_context_->device_description().shared_memory_per_block(); + + uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); + LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); + + auto emit_kernel = [&](absl::Span xor_masks) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = xor_masks.size() > 1 + ? tiled_launch_dimensions + : standard_launch_dimensions; + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + IrArray keys_array; + std::vector values_arrays; + values_arrays.reserve(sort->operand_count() - 1); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); + if (i == 0) { + keys_array = GetIrArray(*sort, *sort, shape_index); + } else { + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + } + } + return llvm_ir::EmitSortInPlace( + dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_masks, + &b_, launch_dimensions, + xor_masks.size() > 1 ? num_iterations_in_sort_dim + : standard_num_iterations_in_sort_dim, + kTileSize); + }; + std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { for (int64 mask = stage; mask >= 0; --mask) { - thunks.push_back( - BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - keys_shape, ir_emitter_context_->device_description()); - UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), - ir_emitter_context_->llvm_module()); - - llvm::Value* xor_mask; + int64 xor_mask; if (mask == stage) { - xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + xor_mask = (1LL << (stage + 1)) - 1; } else { - xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + xor_mask = 1LL << mask; } - - IrArray keys_array; - std::vector values_arrays; - values_arrays.reserve(sort->operand_count() - 1); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); - if (i == 0) { - keys_array = GetIrArray(*sort, *sort, shape_index); - } else { - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + if (xor_mask >= kTileSize || no_tiling) { + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + xor_masks.clear(); } + TF_RETURN_IF_ERROR(emit_kernel({xor_mask})); + } else { + xor_masks.push_back(xor_mask); } - TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_mask, - &b_, &launch_dimensions)); } } + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + } AddThunkToThunkSequence( absl::make_unique(std::move(thunks), sort)); @@ -3261,13 +3339,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( param->shape().element_type(), module_), kTileSize + 1), kTileSize); - const int kNVPTXSharedMemoryAddrSpace = 3; - auto* tile_base_ptr = new llvm::GlobalVariable( - *b_.GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), - llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, - llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); + auto* tile_base_ptr = llvm_ir::AllocateSharedMemoryTile( + b_.GetInsertBlock()->getParent()->getParent(), tile_type, + IrName(hlo, StrCat("tile", id))); param_shmem_buffers[id] = tile_base_ptr; VLOG(3) << "Added shmem buffer for parameter " << id << ": " << llvm_ir::DumpToString(*tile_base_ptr); @@ -3454,6 +3528,29 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( return launch_dimensions; } +namespace { +// Returns true to indicate it is safe to use the tile based shared memory +// transpose implementation to implement the kernel for the instruction. 
+// +// An instruction is not safe for such an implementation if it can change the +// element order of a tensor without changing the dimension of the tensor, and +// the instruction has a corresponding elemental_ir_emitter. +bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { + auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { + HloOpcode opcode = instr->opcode(); + CHECK_NE(opcode, HloOpcode::kFusion); + return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); + }; + + if (hlo->opcode() == HloOpcode::kFusion) { + return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), + is_safe_for_tile_based_transpose); + } + + return is_safe_for_tile_based_transpose(hlo); +} +} // namespace + bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); @@ -3498,6 +3595,10 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } + if (!IsInstructionSafeForTileBasedTranspose(hlo)) { + return false; + } + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb // shared memory per SM. (This is increased to 96kb in Volta, but we don't diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 9427d3d54addc7d794ddc0a8f4c45b39b248bc5f..d9b06828e2b5d334873c88cb49c2e0d5675bb5fe 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -140,6 +140,18 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, return false; } + // The emitter only supports in-place DUS for fusions with a single DUS at the + // root. Don't sibling fuse DUS for now. 
+ // TODO(b/119178699): Multi-output fusing DUS can improve performance if we + // share the input and output buffers and add support to the emitter. + if (instr1->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice || + (instr2->opcode() == HloOpcode::kFusion && + instr2->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice)) { + return false; + } + // Do this check last, as it may be expensive. return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 1d4856e0cae163bbd9ab741917b85792097d8512..dc221f22a74f0875e08e01890ce8ac8fe072cd9d 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -580,7 +580,7 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { // ... // where each of the (pi * pj)'s is represented as a fusion node so that // multi-output fusion will pay attention to it. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); @@ -621,5 +621,39 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { } } +TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { + auto module = ParseHloString(R"(HloModule dus_mof + fusion.1 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + fusion.2 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + ENTRY entry { + p.00 = f16[50,96,1024]{2,1,0} parameter(0) + p.01 = f16[50,96,1024]{2,1,0} parameter(1) + p.1 = s32[1]{0} parameter(2) + p.2 = f16[1,96,1024]{2,1,0} parameter(3) + + f1 = f16[50,96,1024] fusion(p.00, p.1, p.2), kind=kLoop, calls=fusion.1 + f2 = f16[50,96,1024] fusion(p.01, p.1, p.2), kind=kLoop, calls=fusion.2 + ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2) + })") + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5b6cf2c04d05378a363232e33a6df6432cd6848e..4775baf44aecfe6adaf2bf0d2791595436635b16 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -122,7 +122,7 @@ std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = absl::make_unique(); const HloComputation& computation = 
*module.entry_computation(); std::unique_ptr reachability = - computation.ComputeReachability(); + HloReachabilityMap::Build(&computation); std::vector seen_gemms; // The execution of different RNG Hlo instructions in the same module updates // a common global variable. To avoid a race condition, we simply assign all diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index c4f43cc9a614283acb376b5f98e4976615b590ad..f2ef11e1e6ac2405ac2a35fec7b79add9d2b6c17 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,16 +21,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloVerifiedTestBase { +class StreamAssignmentTest : public HloTestBase { protected: - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -55,7 +55,7 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr assignment = AssignStreams(*module); @@ -76,7 +76,7 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, 
HloOpcode::kAdd, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr assignment = AssignStreams(*module); @@ -120,7 +120,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr assignment = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index ed46f08d5970d479db33a7b9ad416a1480535764..d798b31643782eb25bba08227e29903ec0e7a597 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -37,7 +37,7 @@ cc_library( hdrs = ["gpu_codegen_test.h"], tags = tf_cuda_tests_tags(), deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 79e77d4c4d649020cf52ac25c220c3f90e8469b9..9e3ff8750b88d08bcbc1aae3faead5aecfa19848 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -23,9 +23,10 @@ limitations under the License. namespace xla { namespace gpu { -std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { +std::unique_ptr GpuCodegenTest::CreateNewUnverifiedModuleWithFTZ( + bool ftz) { HloModuleConfig config; - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_gpu_ftz(ftz); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index e4a3573babb7ed746504c1466f85b582aa4d044f..d2f30ae7bc4f65675f10a2f87ba934cf308f663a 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -26,9 +26,9 @@ namespace gpu { // Tests that verify IR or PTX emitted by the GPU backend is as expected. class GpuCodegenTest : public LlvmIrGenTestBase { protected: - // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz - // option. - std::unique_ptr CreateNewModuleWithFTZ(bool ftz); + // Like HloTestBase::CreateNewUnverifiedModule(), with a flag for configuring + // the ftz option. + std::unique_ptr CreateNewUnverifiedModuleWithFTZ(bool ftz); // Compiles the given HLO module to PTX and verifies the PTX matches the given // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 780539c164277f14c2bd964024f7c3ca179f4ada..268b48a1cadeef911dfda7e827ae0cd154040be8 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -46,7 +46,7 @@ TEST_F(GpuCopyTest, UseMemcpy) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // There should not be any kernel prefixed "copy". diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index 177b94934c7f519172508b5cc6e088f908401193..d0ccd8619bde9ddd560989380b403efed5c5f42c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -39,7 +39,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/1, param_shape, "y")); builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } @@ -54,7 +54,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/0, param_shape, "x")); builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index a06576df7b874745236a8d9075355a01ec42e777..da8e513a2c3b61eb9f780ac628e4befeb918b939 100644 --- 
a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); // Check the optimized IR as the unoptimized IR contains dead udiv and urem. diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 15d1e269cc22b88f5269175084f20600f165011c..a302b582ede3723acd118d2e4a4bb3efdf7a4d0b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -193,6 +193,33 @@ TEST_F(GpuKernelTilingTest, /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { + const char *const kHloString = R"( + HloModule FusionTransposeWithReverseNotTiled + fused_computation.1 { + arg0 = f32[128,64]{1,0} parameter(0) + copy0 = f32[128,64]{0,1} copy(arg0) + ROOT reverse0 = f32[128,64]{0,1} reverse(copy0), dimensions={0} + } + + ENTRY reverse_break_assumption { + param0 = f32[128,64]{1,0} parameter(0) + ROOT fusion0 = f32[128,64]{0,1} fusion(param0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6a9ecd9dae7c9ddde0b56d8615e4a39fb3df0af9..ea1fee040dd536bcd1c4f8c5dd4f3aaa8dca32f9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -48,7 +48,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -73,7 +73,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { builder.AddInstruction(HloInstruction::CreateTuple({add, square})); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -95,7 +95,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { // reduce in the foreseeable future. But if that turns out to be wrong, I give // you, future reader, permission to delete this test. 
TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation* reduce_computation; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 15198865bda98f9718342d5a444a20305f923b48..14285459b5a7fc0325dc5af80e57bef4ee4b7299 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -47,7 +47,7 @@ TEST_F(GpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 0f2d5568cafc9db0f5f067437fdd5e2e775ad2c8..4636f1d9d20b8c213ffadec427b3820a89c68a7f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -85,7 +85,7 @@ TEST_F(GpuUnrollingTest, UnrollFourTimes) { TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { // The default unrolling factor is 4. 
HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc index 5fa9e91050a85b67eb22a48d47e4dd157a53c699..3d00ac4dc7b57664a317157c093d7ffaa01b4fd6 100644 --- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -32,7 +32,7 @@ namespace gpu { namespace { using match::Concatenate; -class VariadicOpSplitterTest : public HloVerifiedTestBase {}; +class VariadicOpSplitterTest : public HloTestBase {}; TEST_F(VariadicOpSplitterTest, DontSplit) { auto module = ParseAndReturnVerifiedModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 926b59a1b854bd3d7d2699124e10b70147e52e2a..c7f51127649664189050e2128ae1e56547358c23 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -29,7 +29,7 @@ namespace { class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(CreateNewModule()), + : module_(CreateNewUnverifiedModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), 
data_shape_(ShapeUtil::MakeShape(F32, {8})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index e30e7667f3015bc7bfe67c65147a5016332780f7..fad3215fc81e1012ddaa5a6458bc98f90ab97f18 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,16 +30,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; +class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module); + HloSchedule schedule(module.get()); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -351,7 +351,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloVerifiedTestBase { +class HeapSimulatorTest : public HloTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc 
b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 5c8d97b2d15e15d15cb8014a7d25b37437ce8aec..7e6150e94153cd15463725e862ce1b8593f2c991 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,17 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloVerifiedTestBase { +class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : HloVerifiedTestBase() { - module_ = CreateNewModule(); + HloAliasAnalysisTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_, + analysis_ = HloAliasAnalysis::Run(module_.get(), /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -93,7 +93,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. 
bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -110,7 +110,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { return false; } - HloModule* module_; + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -638,7 +638,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_).status()); + TF_ASSERT_OK(flattener.Run(module_.get()).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -1012,7 +1012,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -1054,13 +1054,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } // For a sequential order, if there is interference iff the negate is after // the while. 
- HloSchedule schedule(module_); + HloSchedule schedule(module_.get()); schedule.set_sequence(body, {body_param, body_root}); schedule.set_sequence(condition, {cond_param, cond_root}); { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b0f7cd91ad1db0a59c09cfbfc1885813dc57e01e..0c20d207ddbca82e2f87800d331d1bace39a512e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -321,7 +322,7 @@ void HloComputation::ComputeInstructionPostOrder( // Add the operands to the stack in reverse order so the first operand is // processed first. This will produce a more natural ordering and a nicer - // result for thigns like HLO stringification. + // result for things like HLO stringification. 
const auto& operands = current->operands(); for (int64 i = operands.size() - 1; i >= 0; --i) { dfs_stack.emplace_back(operands[i]); @@ -739,72 +740,6 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -std::unique_ptr HloComputation::ComputeReachability() - const { - const auto& all = MakeInstructionPostOrder(); - auto result = absl::make_unique(all); - auto channel_dependency_map = ComputeChannelDependencies(); - - std::vector inputs; - for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); - - switch (hlo->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - break; - } - case HloOpcode::kCrossReplicaSum: { - auto all_reduce_id = hlo->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - } - break; - } - default: - break; - } - - result->FastSetReachabilityToUnion(inputs, hlo); - } - return result; -} - -void HloComputation::UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map) { - std::queue worklist; - worklist.push(instruction); - - std::vector inputs; - - while (!worklist.empty()) { - const HloInstruction* item = worklist.front(); - worklist.pop(); - - inputs.assign(item->operands().begin(), item->operands().end()); - inputs.insert(inputs.end(), item->control_predecessors().begin(), - item->control_predecessors().end()); - - if (reachability_map->SetReachabilityToUnion(inputs, item)) { - // Add immediate successors to worklist. 
- for (const HloInstruction* user : item->users()) { - worklist.push(user); - } - for (const HloInstruction* succ : item->control_successors()) { - worklist.push(succ); - } - } - } -} - std::vector HloComputation::CollectUnreachableRoots() const { std::vector unreachable_roots; for (auto* instruction : instructions()) { @@ -911,14 +846,46 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - /*extras=*/{}, context, suffix); + context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + replacements.emplace(std::move(r3)); + return CloneWithReplacements(std::move(replacements), context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context, - const string& suffix) { + HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { context_ptr = absl::make_unique(parent(), suffix); @@ -939,18 +906,50 @@ std::unique_ptr HloComputation::CloneWithReplacements( }; VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; + + 
// We want to do a postorder walk over [replace(i) for i in instructions_]. + // We can't reuse MakeInstructionPostOrder() for this, because that will + // generate a postorder of plain instructions_, and our replacements may + // change the postorder! + // + // The postorder we want here is simpler than what MakeInstructionPostOrder() + // does -- we only care about operand dependencies -- so let's just do it + // ourselves. std::vector postorder; - for (HloInstruction* instr : extras) { - postorder.push_back(instr); - } - for (HloInstruction* instr : MakeInstructionPostOrder()) { - if (HloInstruction* replacement = replace(instr)) { - postorder.push_back(replacement); + absl::flat_hash_map visited; + for (const auto& instr : instructions_) { + std::vector dfs_stack; + HloInstruction* new_instr = replace(instr.get()); + if (!new_instr) { + continue; + } + dfs_stack.push_back(new_instr); + + while (!dfs_stack.empty()) { + auto* cur = dfs_stack.back(); + auto it = visited.find(cur); + if (it != visited.end()) { + dfs_stack.pop_back(); + if (it->second == kVisited) { + continue; + } + CHECK_EQ(it->second, kVisiting); + postorder.push_back(cur); + it->second = kVisited; + continue; + } + + visited.insert({cur, kVisiting}); + for (HloInstruction* operand : cur->operands()) { + HloInstruction* new_operand = replace(operand); + if (new_operand) { + dfs_stack.emplace_back(new_operand); + } + } } } std::vector> instructions; - std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { @@ -960,9 +959,8 @@ std::unique_ptr HloComputation::CloneWithReplacements( << operand->ToString() << ", used by " << instr->ToString(); new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, context); - instructions.push_back(std::move(new_instr)); + instructions.push_back( + instr->CloneWithNewOperands(instr->shape(), new_operands, 
context)); } Builder builder(name() + "." + suffix); for (auto& instr : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index dec96d11a93cf56d3c40a6bb7882ffb7336aeeb0..fc7d2035e5bd0b99fa9e7a026430172f686019d4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" @@ -215,19 +214,6 @@ class HloComputation { // this order, definitions of values always appear before their uses. std::vector MakeInstructionPostOrder() const; - // Computes and returns the reachability between HLO instructions in the - // computation. The returned HloReachabilityMap is constructed such that - // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a - // directed path (from producer to consumer) from 'a' to 'b'. Both data - // dependencies (operands) and control dependencies are considered for - // reachability. Trivially an instruction is reachable from itself. - std::unique_ptr ComputeReachability() const; - - // Updates the given reachability map after the immediate predecessor set - // (operands and control predecessors) of 'instruction' has changed. 
- void UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this @@ -333,14 +319,38 @@ class HloComputation { // the map's value to replace that instruction in the cloned computation. // // If replacements maps a key to nullptr, we remove that instruction from the - // new computation. - // If additional instructions are used by instructions in replacement map, - // they must be passed in post-order in the extras span. + // new computation. If an element of `replacements` references an instruction + // that's not already in the computation, it's cloned and added to the new + // computation. + // + // All relevant instructions are cloned, *including* unique_ptr in the + // `replacements` map. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); + + // Convenience overloads for CloneWithReplacements. You want to do + // + // CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}}) // ERROR + // + // but that doesn't work because std::initializer_list is not movable. 
These + // overloads let you do + // + // CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)}); // OK + // + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -355,6 +365,14 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + absl::flat_hash_map>; + ChannelDependencyMap ComputeChannelDependencies() const; + // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. bool HasSideEffect() const; @@ -410,14 +428,6 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating - // instructions. 
- using ChannelDependencyMap = - absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; - enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 2aaaef1d36d58bcce18db4aa37ff05ea352e484b..1e7a6e197f5b6c3070b7cad2c14f62521290a4c9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -65,7 +65,7 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto negate_computation = module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); @@ -73,7 +73,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map_computation = @@ -85,7 +85,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map1_computation = @@ -119,7 +119,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -134,7 +134,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); @@ -151,7 +151,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Trace instructions should be at the end of the sort. 
EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -170,7 +170,7 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), UnorderedElementsAre(constant1, constant2, constant3, constant4)); @@ -192,7 +192,7 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); @@ -217,7 +217,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. 
class TestVisitor : public DfsHloVisitorWithDefault { @@ -257,7 +257,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -274,7 +274,7 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -376,7 +376,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { // copied. auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); @@ -393,7 +393,7 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); @@ -412,7 +412,7 @@ TEST_F(HloComputationTest, CycleDetection) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); - auto module = CreateNewModule(); + auto module = 
CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); @@ -440,7 +440,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); @@ -466,7 +466,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); @@ -484,107 +484,6 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } -TEST_F(HloComputationTest, Reachability) { - // Test reachability of a non-trivial computation: - // - // const1 const2 - // | | - // | +-------+ - // | | | - // add .. negate - // | . | - // | .... exp - // | | - // +---+ +-+---+ - // | | | - // multiply copy - // - // There is a control dependency from 'add' to 'exp'. 
- auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); - - auto module = CreateNewModule(); - auto computation = - module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); - - TF_CHECK_OK(add->AddControlDependencyTo(exp)); - auto reachability = computation->ComputeReachability(); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_TRUE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_TRUE(reachability->IsReachable(constant1, copy)); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_TRUE(reachability->IsReachable(constant2, negate)); - EXPECT_TRUE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_TRUE(reachability->IsReachable(constant2, copy)); - - EXPECT_FALSE(reachability->IsReachable(exp, constant1)); - 
EXPECT_FALSE(reachability->IsReachable(exp, constant2)); - EXPECT_FALSE(reachability->IsReachable(exp, add)); - EXPECT_FALSE(reachability->IsReachable(exp, negate)); - EXPECT_TRUE(reachability->IsReachable(exp, exp)); - EXPECT_TRUE(reachability->IsReachable(exp, mul)); - EXPECT_TRUE(reachability->IsReachable(exp, copy)); - - EXPECT_FALSE(reachability->IsReachable(mul, constant1)); - EXPECT_FALSE(reachability->IsReachable(mul, constant2)); - EXPECT_FALSE(reachability->IsReachable(mul, add)); - EXPECT_FALSE(reachability->IsReachable(mul, negate)); - EXPECT_FALSE(reachability->IsReachable(mul, exp)); - EXPECT_TRUE(reachability->IsReachable(mul, mul)); - EXPECT_FALSE(reachability->IsReachable(mul, copy)); - - EXPECT_TRUE(reachability->IsConnected(constant1, copy)); - EXPECT_TRUE(reachability->IsConnected(copy, constant1)); - EXPECT_FALSE(reachability->IsConnected(negate, add)); - EXPECT_FALSE(reachability->IsConnected(add, negate)); - - // Remove the control dependency then update and verify the reachability map - ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); - computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_FALSE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_FALSE(reachability->IsReachable(constant1, copy)); - - // Change a use within the graph then update and verify the reachability map - ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); - computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - 
EXPECT_FALSE(reachability->IsReachable(constant2, negate)); - EXPECT_FALSE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_FALSE(reachability->IsReachable(constant2, copy)); -} - TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -606,7 +505,7 @@ TEST_F(HloComputationTest, Stringification) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -641,7 +540,7 @@ TEST_F(HloComputationTest, StringificationIndent) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = @@ -677,7 +576,7 @@ TEST_F(HloComputationTest, StringificationCanonical) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -700,27 +599,5 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -TEST_F(HloComputationTest, ChannelReachability) { - const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); - HloComputation::Builder builder("ChannelReachability"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto token0 = 
builder.AddInstruction(HloInstruction::CreateToken()); - auto send = - builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); - auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); - auto recv = - builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); - auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(recv_done)); - auto reachability = computation->ComputeReachability(); - EXPECT_TRUE(reachability->IsReachable(param, recv_done)); - EXPECT_FALSE(reachability->IsReachable(send, recv)); - EXPECT_FALSE(reachability->IsReachable(send_done, recv)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 4f898ce61c3f36e83e4b13130a404dbb4a2c36c6..5e37883d3d8d5067bab873ac6b5f732e7360c5fa 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -52,8 +52,10 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, and AfterAll operation. - // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // Skip Constant, Parameter, Tuple, AfterAll operation. + // Tuple constants are not directly supported by any backends, hence + // folding Tuple is not useful and would in fact be expanded back into + // kTuple by Algebraic Simplifier. // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this // special case is not necessary. 
@@ -63,6 +65,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { instruction->opcode() == HloOpcode::kAfterAll) { continue; } + // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index e45f905f7152c37a9ab2b41d407310671310c2a3..d12f920722e20a3390a99f74c8a10c7c9e3fdf6c 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloVerifiedTestBase; +using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -46,13 +46,13 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -67,13 +67,13 @@ 
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -88,13 +88,13 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -130,11 +130,11 @@ TEST_F(HloConstantFoldingTest, Concatenate) { Shape shape = ShapeUtil::MakeShape(F32, dimensions); builder.AddInstruction(HloInstruction::CreateConcatenate( shape, operands, test_config.concat_dimension)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -157,11 +157,11 @@ TEST_F(HloConstantFoldingTest, Slice) { Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); 
builder.AddInstruction(HloInstruction::CreateSlice( shape, literal_instruction, slice_start, slice_limits, slice_strides)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -182,11 +182,11 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { const int64 permutation[] = {1, 2, 0, 4, 3}; builder.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -219,27 +219,28 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - ParseAndVerifyModule(kConstantFoldReduce); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_TRUE(result); - EXPECT_EQ(6, module() - .entry_computation() + EXPECT_EQ(6, m->entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - ParseAndVerifyModule(kConstantFoldReduce); - HloInstruction* add = module().computations().begin()->root_instruction(); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); + 
HloInstruction* add = m->computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_FALSE(result); - EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reduce()); } const char* const kConstantFoldLargePad = R"( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 108aeea097d7170d236b988c414b517a1a284640..fdfb38b858c32ba5b092ec2db84d4bac487c3e78 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -269,7 +269,7 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(map->to_apply())); + ProcessNestedSubcomputation(map->to_apply())); // Compute the cost of all elements for this Map operation. const int64 element_count = ShapeUtil::ElementsIn(map->shape()); @@ -285,7 +285,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. // This counts the number of times the reduction function is applied, so it @@ -311,7 +311,7 @@ Status HloCostAnalysis::HandleReduceWindow( auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. 
TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each // output element there are window_size - 1 reductions to perform. @@ -336,9 +336,9 @@ Status HloCostAnalysis::HandleSelectAndScatter( // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties select_properties, - ProcessSubcomputation(instruction->select())); + ProcessNestedSubcomputation(instruction->select())); TF_ASSIGN_OR_RETURN(const Properties scatter_properties, - ProcessSubcomputation(instruction->scatter())); + ProcessNestedSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter // source element there are window_size - 1 select computations to perform and @@ -574,7 +574,7 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation())); + ProcessNestedSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. 
// Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -595,7 +595,7 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { Status HloCostAnalysis::HandleCall(const HloInstruction* call) { TF_ASSIGN_OR_RETURN(current_properties_, - ProcessSubcomputation(call->to_apply())); + ProcessUnnestedSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -624,13 +624,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. - // - // TODO(b/26346211): Improve the cost analysis for while nodes. TF_ASSIGN_OR_RETURN(const Properties body_properties, - ProcessSubcomputation(xla_while->while_body())); + ProcessUnnestedSubcomputation(xla_while->while_body())); - TF_ASSIGN_OR_RETURN(const Properties condition_properties, - ProcessSubcomputation(xla_while->while_condition())); + TF_ASSIGN_OR_RETURN( + const Properties condition_properties, + ProcessUnnestedSubcomputation(xla_while->while_condition())); current_properties_.clear(); for (const auto& property : body_properties) { @@ -647,10 +646,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { // Compute the cost of the true and false computations and take the maximum // from those for each property. 
- TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, - ProcessSubcomputation(conditional->true_computation())); - TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, - ProcessSubcomputation(conditional->false_computation())); + TF_ASSIGN_OR_RETURN( + const Properties true_computation_properties, + ProcessUnnestedSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN( + const Properties false_computation_properties, + ProcessUnnestedSubcomputation(conditional->false_computation())); current_properties_ = true_computation_properties; for (const auto& property : false_computation_properties) { if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { @@ -680,7 +681,7 @@ Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { const int64 element_count = ShapeUtil::ElementsIn(scatter->operand(2)->shape()); TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(scatter->to_apply())); + ProcessNestedSubcomputation(scatter->to_apply())); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * element_count; @@ -689,6 +690,11 @@ Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { return Status::OK(); } +Status HloCostAnalysis::HandleGetDimensionSize( + const HloInstruction* /*get_size*/) { + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -725,10 +731,19 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } -StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation) { +StatusOr +HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.properties(); +} + 
+StatusOr +HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) { HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + hlo_properties_.insert(visitor.hlo_properties_.begin(), + visitor.hlo_properties_.end()); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 46b4bbeef222e6de581360fc01b293e812f1dedd..8ced9d776e150ac587e9ac3ed0beffbc38dc5503 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -107,6 +107,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; Status HandleScatter(const HloInstruction* scatter) override; + Status HandleGetDimensionSize(const HloInstruction* get_size) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -153,7 +154,24 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // Returns the properties computed from visiting the computation rooted at the // given hlo. - StatusOr ProcessSubcomputation(HloComputation* computation); + // + // The difference between ProcessNestedSubcomputation and + // ProcessUnnestedSubcomputation is that we expect to get profile results for + // an unnested subcomputation's individual instructions, while we expect that + // a nested subcomputation is completely subsumed by its parent. + // + // For example, subcomputations inside kFusion and kMap are considered nested, + // while subcomputations inside kWhile and kConditional are considered + // unnested. 
+ // + // Another way of thinking of this is, kFusion is implemented on the GPU + // backend using just one GPU kernel, while kWhile's body is implemented as a + // sequence of kernels, one for each HLO therein. Backends don't necessarily + // need to follow this same implementation strategy, but we assume they do for + // the purposes of this platform-generic cost analysis. + StatusOr ProcessNestedSubcomputation(HloComputation* computation); + StatusOr ProcessUnnestedSubcomputation( + HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9acee892d5993be3498d51ed66d7fa4647d7de88..6a15b3440c6f9bd2cac5ea10a0883330260b89e5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -387,7 +387,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -429,7 +429,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -472,7 +472,7 @@ TEST_F(DomainCostAnalysis, DomainCost) { auto domain = builder.AddInstruction( 
HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8..aaa9ec60eb3c4e0159ed40b37d772e0973d306ec 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,22 +19,22 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -class HloCreationUtilsTest : public HloVerifiedTestBase { +class HloCreationUtilsTest : public HloTestBase { protected: - HloModule* CreateModuleWithProgramShape( + std::unique_ptr CreateModuleWithProgramShape( PrimitiveType primitive_type, absl::Span input_shape_dims, absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); - auto module = CreateNewModule("test"); + auto module = CreateNewVerifiedModule("test"); *entry_computation = module->AddEntryComputation( CreateComputationWithSignature({&input_shape}, output_shape, "entry") .ValueOrDie()); @@ -47,10 +47,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - 
/*output_shape_dims=*/{2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); @@ -67,9 +66,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, + auto module = CreateModuleWithProgramShape( + S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed, @@ -92,10 +90,9 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{1, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -113,10 +110,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -134,10 +130,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - 
/*input_shape_dims=*/{}, - /*output_shape_dims=*/{1, 1}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -154,10 +149,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6}, + /*output_shape_dims=*/{3, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded, ExpandFirstDimIntoNDims(param, {3, 1, 2})); @@ -176,10 +170,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{6}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -197,10 +190,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -218,10 +210,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(F32, - /*input_shape_dims=*/{}, - 
/*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9b18b0284f63c25934c1b7118dc8973caa62cadc..1eb0260468c4560985027947e89c62cc21139e7e 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloVerifiedTestBase { +class HloCseTest : public HloTestBase { protected: HloCseTest() {} }; @@ -59,13 +59,13 @@ TEST_F(HloCseTest, CombineTwoConstants) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); @@ -89,14 +89,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto add = 
builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); @@ -121,14 +121,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { shape_r0, HloOpcode::kAdd, root, constants[i])); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -201,7 +201,7 @@ TEST_F(HloCseTest, NonscalarConstants) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( {common_constant1, common_constant2, uncommon_constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -233,14 +233,14 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -277,21 +277,21 @@ index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - } - )"); + })"; - 
auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -327,20 +327,20 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 - } - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -369,22 +369,21 @@ condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), condition=%condition.1, body=%body - } 
- - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -411,13 +410,14 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -439,14 +439,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, 
computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -470,14 +470,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -488,7 +488,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { TEST_F(HloCseTest, FusionInternalCSE) { // Test that we can CSE expressions that live within a fusion node // computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); @@ -512,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -554,14 +554,14 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + 
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -586,7 +586,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -595,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -607,7 +607,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // rng_function is an impure function because it does RNG. 
HloComputation* rng_function = nullptr; @@ -649,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -659,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule m add_computation { @@ -680,11 +680,12 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })"); + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -697,19 +698,19 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -730,11 +731,12 @@ ENTRY %entry { 
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})"); +})"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - const HloInstruction* sub = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + const HloInstruction* sub = m->entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 909853106d57d181e85e3e4134b4039be2b176f5..6422346c1011b95bb511a1fcdfee5c84647f0571 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -43,7 +43,7 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(CreateNewModule()) {} + HloDataflowAnalysisTest() : module_(CreateNewUnverifiedModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
@@ -1884,7 +1884,7 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, class HloDataflowAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -2476,7 +2476,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -2511,7 +2511,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 3b5cde2996c4195ef458662cd21de85a832d8d55..6c8095d39774b247e136442c92c8ecf17432701c 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -59,7 +59,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -80,7 +80,7 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); - auto module = CreateNewModule(); + auto module = 
CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -110,7 +110,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -150,7 +150,7 @@ TEST_F(HloDceTest, ControlDependencies) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency between two instructions. @@ -175,7 +175,7 @@ TEST_F(HloDceTest, ControlDependencies) { // Tests that a dead call instruction is removed. TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Called computation for the call instruction. @@ -215,7 +215,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { // Tests that a while instruction with an infeed (effectul instruction) in its // body is not removed, even its user count is 0. TEST_F(HloDceTest, CalledComputationWithSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Condition computation of a while instruction. @@ -270,7 +270,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { // Tests that a nested call instruction with a side effect is not removed. 
TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Nested called computation with a side effect. @@ -323,7 +323,7 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { } TEST_F(HloDceTest, RemoveDeadSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); @@ -364,7 +364,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { } TEST_F(HloDceTest, KeepUsedSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index b90e8db23398d23e886d2d1fe68de8bb187d9c3a..acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" @@ -22,13 +22,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloVerifiedTestBase { +class HloDomainTest : public HloTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,13 +63,6 @@ class HloDomainTest : public HloVerifiedTestBase { } return false; } - - StatusOr ParseModule(absl::string_view hlo_string) { - HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_string, config); - return &module(); - } }; // Dummy DomainMetadata implementation which create kDomain boundaries around @@ -144,31 +136,32 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + 
EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -186,11 +179,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(!isolator_changed); } @@ -213,26 +207,27 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); 
EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "b", "a")); - EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); - EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "b", "a")); - EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -250,11 +245,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_FALSE(isolator_changed); } @@ -273,15 +269,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - 
TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module, "c"); + HloInstruction* add = FindInstruction(module.get(), "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -304,41 +301,42 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator sharding_isolator([]() { return ShardingDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module)); + sharding_isolator.Run(module.get())); EXPECT_TRUE(sharding_isolator_changed); HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), 
OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -359,16 +357,17 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); - EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module.get(), "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. 
@@ -384,7 +383,7 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + HloInstruction* infeed_data = FindInstruction(module.get(), "infeed.data"); ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); @@ -410,7 +409,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); struct Assignment { @@ -446,25 +445,26 @@ ENTRY entry { sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); - EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + EXPECT_TRUE(HasDomainEdge(module.get(), "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module.get(), "gte", "tuple")); // Remove %tuple and %gte (tuple simplification) - HloInstruction* gte = FindInstruction(module, "gte"); - HloInstruction* tuple = FindInstruction(module, "tuple"); + HloInstruction* gte = FindInstruction(module.get(), "gte"); + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); 
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); const HloInstruction* root = module->entry_computation()->root_instruction(); @@ -486,11 +486,11 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto hlo_string = module->ToString(); - ASSERT_TRUE(ParseModule(hlo_string).status().ok()); + ASSERT_TRUE(ParseAndReturnVerifiedModule(hlo_string).status().ok()); } // Tuple inputs are domain instructions. @@ -507,20 +507,21 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tpl instruction, in order to test domain sharding // application. 
- auto tpl = FindInstruction(module, "tpl"); + auto tpl = FindInstruction(module.get(), "tpl"); tpl->clear_sharding(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), @@ -555,36 +556,37 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool 
opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } // Emulate instructions inserted at top and bottom within nested tuple domain. @@ -603,15 +605,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tuple.0 instruction, in order to test domain sharding // application. 
- auto tuple0 = FindInstruction(module, "tuple.0"); + auto tuple0 = FindInstruction(module.get(), "tuple.0"); tuple0->clear_sharding(); // Insert the following instructons above and below tuple.0, to emulate other @@ -655,7 +658,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_TRUE(tuple0->has_sharding()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 608a42bb60702aa075daca39535ca1672dcc5467..d95b6ad04f2c446b423a3aaef4de333ed2968883 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -50,9 +50,9 @@ namespace { static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloVerifiedTestBase { + public HloTestBase { protected: - HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } @@ -60,14 +60,14 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(&module()).ValueOrDie(); + type_converter.Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - // Evaluate function that takes in a local module instead of using module_ - // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // Evaluate function that takes in a local module instead of using m_ + // that is in HloTestBase. Once m_ in HloTestBase is // removed, this should be the default Evaluate function. Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { @@ -88,7 +88,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -108,7 +108,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -116,6 +116,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, } bool use_bfloat16_; + std::unique_ptr m_ = CreateNewVerifiedModule(); }; #define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ @@ -135,7 +136,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ 
-156,7 +157,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -181,7 +182,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -322,7 +323,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(args); @@ -346,7 +347,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -367,7 +368,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, {1, 2})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -386,7 +387,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ 
-406,7 +407,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -428,7 +429,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -448,7 +449,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -468,7 +469,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -503,7 +504,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -530,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - 
module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -574,7 +575,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -619,7 +620,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -658,7 +659,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -704,7 +705,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -748,7 +749,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -802,7 +803,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -857,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, 
window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -941,7 +942,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1019,7 +1020,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1079,7 +1080,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1143,7 +1144,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1215,7 +1216,7 @@ TEST_P(HloEvaluatorTest, b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1286,7 +1287,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { 
b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1297,11 +1298,12 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1319,12 +1321,12 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m->AddEmbeddedComputation(add_computation.Build()); HloInstruction* reduce_instruction = b.AddInstruction( HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{0}, add_func)); - module().AddEntryComputation(b.Build()); + m->AddEntryComputation(b.Build()); HloEvaluator hlo_eval; Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); @@ -1337,7 +1339,7 @@ void BM_ReducePrecisely(int num_iters) { tensorflow::testing::StopTiming(); HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_ReducePrecisely", config); constexpr int 
kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1396,14 +1398,14 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction( HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1438,7 +1440,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1455,7 +1457,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1490,7 +1492,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1507,7 +1509,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { 
b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1541,7 +1543,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; WindowDimension dim; @@ -1564,7 +1566,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1594,7 +1596,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; @@ -1625,7 +1627,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1657,7 +1659,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1691,7 +1693,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); 
b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1727,7 +1729,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1764,7 +1766,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1800,7 +1802,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1839,7 +1841,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1877,7 +1879,7 @@ TEST_P(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1966,7 +1968,7 @@ ENTRY main { slice_sizes={1, 3} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = 
LiteralUtil::CreateR1({0, 2}); @@ -1990,7 +1992,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -2014,7 +2016,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2039,7 +2041,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2066,7 +2068,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2092,7 +2094,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({1, 1}); @@ -2115,7 +2117,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2139,7 +2141,7 @@ ENTRY main { slice_sizes={1, 0} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); 
Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), @@ -2161,7 +2163,7 @@ ENTRY main { slice_sizes={1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal start_indices = @@ -2192,7 +2194,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2223,7 +2225,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2256,7 +2258,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2288,7 +2290,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2320,7 +2322,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); @@ -2354,7 +2356,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, 
ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2386,7 +2388,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2418,7 +2420,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2455,7 +2457,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2491,7 +2493,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2523,7 +2525,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2555,7 +2557,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{}, {}}); @@ -2585,7 +2587,7 @@ ENTRY main { 
index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal scatter_indices = @@ -2736,7 +2738,7 @@ ENTRY main { ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); @@ -2754,7 +2756,7 @@ ENTRY main { ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 13a74fd8a115c5dc9a9518b226dfee4445cc7180..05cc1593e4ef4fc52b94e0536628645b1fa2abbc 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1043,6 +1043,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGetDimensionSize: return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f6ed86b41650fd331201814559386ff644092c23..26786ee950b5421f79fc089d65f1395aae65d335 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -312,6 +312,10 @@ StatusOr> HloInstruction::CreateFromProto( proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { + 
TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) && + (ShapeUtil::TupleElementCount(proto.shape()) == 2)) + << "Infeed should have a tuple shape with 2 operands, but has: " + << proto.shape(); const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -530,6 +534,12 @@ StatusOr> HloInstruction::CreateFromProto( absl::make_unique(exit_hlo_sharding)); break; } + case HloOpcode::kGetDimensionSize: + TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.dimensions_size() == 1); + instruction = CreateGetDimensionSize(proto.shape(), operands(0), + proto.dimensions(0)); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -1001,6 +1011,14 @@ HloInstruction::CreateSelectAndScatter( broadcast_dimensions); } +/* static */ std::unique_ptr +HloInstruction::CreateGetDimensionSize(const Shape& shape, + HloInstruction* operand, + int64 dimension) { + return absl::make_unique(shape, operand, + dimension); +} + /* static */ std::unique_ptr HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, @@ -1109,7 +1127,11 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) { void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { - if (sharding_ != nullptr) { + if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType( + shape_, derived_instruction->shape())) { + // Only copy sharding if the shape of the two instruction is compatible + // because copying it between differently shaped instructions can produce + // invalid shardings. 
derived_instruction->set_sharding(*sharding_); } else { derived_instruction->clear_sharding(); @@ -1268,6 +1290,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kIota: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1715,6 +1738,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kScatter: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2440,6 +2464,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAfterAll(this); case HloOpcode::kIota: return visitor->HandleIota(this); + case HloOpcode::kGetDimensionSize: + return visitor->HandleGetDimensionSize(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2639,49 +2665,7 @@ Status HloInstruction::Accept( return this->Accept(&visitor); } -Status HloInstruction::AcceptOrdered( - DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; - TF_RET_CHECK(OrderIsTopologicalSort(order)); - - // Compute the predecessors of this instruction. - std::unordered_set predecessors; - TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) { - predecessors.insert(instruction); - return Status::OK(); - })); - - for (auto* const_instruction : order) { - if (!ContainsKey(predecessors, const_instruction)) { - // Instruction is not a predecessors of 'this'. - continue; - } - - // The visitor can mark instructions as visited to skip particular - // instructions. - if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO %" << const_instruction->name() - << " as it was already visited."; - continue; - } - - // TODO(b/78350259): Eliminate const laundering. 
- HloInstruction* instruction = - const_cast(const_instruction); - - TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO %" << instruction->name(); - TF_RETURN_IF_ERROR(instruction->Visit(visitor)); - visitor->SetVisited(*instruction); - TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); - } - - return visitor->FinishVisit(this); -} - -const Shape& HloInstruction::shape() const { - return shape_; -} +const Shape& HloInstruction::shape() const { return shape_; } std::vector HloInstruction::OperandIndices( const HloInstruction* operand) const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 15a4da8dbe0053aad314989a6718ebd61532ab8b..818d4ede0f30f06d390daa70c508c6be6bbc38ce 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -767,6 +767,9 @@ class HloInstruction { // when we plumb a primordial token from the entry computation. static std::unique_ptr CreateToken(); + static std::unique_ptr CreateGetDimensionSize( + const Shape& shape, HloInstruction* operand, int64 dimension); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -954,16 +957,6 @@ class HloInstruction { Status Accept( const std::function& visitor_func) const; - // Visits all instructions rooted at this instruction using the given visitor - // in the given order. 'order' must contain at least the set of instructions - // rooted at this node (ie, those accessible from a DFS traversal from this - // instruction). Instructions contained in 'order' which are not in the set of - // instructions rooted at this node are ignored. 'order' must also be a valid - // topological sort of these instructions (defs appear before uses) though - // need not be a DFS post-order. 
- Status AcceptOrdered(DfsHloVisitor* visitor, - const std::vector& order); - // Visit this instruction and only this instruction with the given visitor. template Status Visit(DfsHloVisitorBase* visitor); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d93351fe0435b5f29035dc4ea0621a8c576bfd5a..8048e332cb57747286758b75773b29ba154aa888 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,7 +39,7 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloVerifiedTestBase { +class HloInstructionTest : public HloTestBase { protected: Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -151,7 +151,7 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); @@ -188,7 +188,7 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); 
module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); @@ -221,7 +221,7 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -256,7 +256,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -305,7 +305,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -327,7 +327,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Builds an x+1.0 computation to use in a Map. 
auto embedded_builder = HloComputation::Builder("f32+1"); @@ -375,7 +375,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. @@ -416,7 +416,7 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -451,7 +451,7 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -479,7 +479,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -516,7 +516,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -546,7 +546,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); @@ -611,7 +611,7 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; @@ -629,7 +629,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -647,7 +647,7 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -669,7 +669,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( 
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -692,7 +692,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -749,7 +749,7 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); @@ -817,7 +817,7 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -977,7 +977,7 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); int visit_num = 0; @@ -1006,7 +1006,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { 
builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); @@ -1016,7 +1016,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { } TEST_F(HloInstructionTest, MapIsElementwise) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); HloComputation::Builder builder(TestName()); HloComputation::Builder map_builder("id"); @@ -1067,7 +1067,7 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); @@ -1108,7 +1108,7 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); @@ -1151,7 +1151,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = 
computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1192,7 +1192,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1204,7 +1204,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { } TEST_F(HloInstructionTest, FusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1226,7 +1226,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. 
@@ -1330,7 +1330,7 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( @@ -1373,7 +1373,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1408,7 +1408,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1443,7 +1443,7 @@ TEST_F(HloInstructionTest, StringifyScatter) { update_builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* update_computation = module->AddEmbeddedComputation(update_builder.Build()); @@ -1495,7 +1495,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1531,7 +1531,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - 
auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1587,7 +1587,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 88495e80000c4f87a778c4fad747f6bdf09b7a14..4c765aa375cd788612d144484df041dd6cd989f4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2349,4 +2349,43 @@ HloInstructionProto HloDomainInstruction::ToProto() const { return proto; } + +HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction( + const Shape& shape, HloInstruction* operand, int64 dimension) + : HloInstruction(HloOpcode::kGetDimensionSize, shape), + dimension_(dimension) { + AppendOperand(operand); +} + +HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(dimension()); + return proto; +} + +std::vector HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + return {StrCat("dimensions={", dimension(), "}")}; +} + +bool HloGetDimensionSizeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return dimension() == casted_other.dimension(); +} + +std::unique_ptr 
+HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + if (new_operands.size() != 1) { + LOG(FATAL) << "expects 1 operand"; + } + return absl::make_unique( + shape, new_operands[0], dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index bf4daf2be47ed06d2b88a331a56149d38fa646b3..d43a8973ccff697c27462b611446215df71973a5 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1385,6 +1385,33 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; }; + +class HloGetDimensionSizeInstruction : public HloInstruction { + public: + explicit HloGetDimensionSizeInstruction(const Shape& shape, + HloInstruction* operand, + int64 dimension); + + // Returns the dimension sizes or numbers associated with this instruction. + int64 dimension() const { return dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 1717770301e3666b0a1c23d20b7f2e3bac5f62e4..170ec93a334903cdc314f1950675ef30bc4cda5a 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -165,6 +165,7 @@ namespace opcode_matchers { } HLO_MATCHER(Abs); HLO_MATCHER(Add); +HLO_MATCHER(AllToAll); HLO_MATCHER(Bitcast); HLO_MATCHER(Broadcast); HLO_MATCHER(BatchNormGrad); @@ -178,6 +179,7 @@ HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 2f15997fc175c44121bbf0ace2940fa01f465f92..984a6266abb28f154a015e79645317e4e246fd0b 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -65,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloMemoryScheduler scheduler([](const BufferValue& buffer) { @@ -172,7 +172,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, abs_abs2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); 
TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, ScheduleModule(*module, @@ -218,7 +218,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( @@ -242,7 +242,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { } TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); // param != 0 diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 6a838b7eb969d514dad6bf62a37bf84cd96de8de..14bf17f4be16f8cf820753bc9f0473029834f1f8 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -559,7 +559,8 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { std::unique_ptr HloModule::Clone(const HloModuleConfig& config, const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = absl::make_unique(name_ + "-" + suffix, config); + auto module = absl::make_unique( + absl::StrCat(name_, suffix.empty() ? 
"" : "-", suffix), config); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 39f38b417ab0e8b54864176d8d1e0ad1a422eca6..3ae67e4e5ee90ca182c7c3d97a67d070431ce851 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -63,7 +63,7 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); EXPECT_THAT(module->MakeComputationPostOrder(), @@ -72,7 +72,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); @@ -88,7 +88,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { TEST_F(HloModuleTest, CloneTest) { // Create and copy a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -111,7 +111,7 @@ TEST_F(HloModuleTest, CloneTest) { } TEST_F(HloModuleTest, CloneHasFusion) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); // Create the fused computation. 
HloComputation* fused_computation; @@ -154,7 +154,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -174,7 +174,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { TEST_F(HloModuleTest, LargeConstantToString) { // Create a module with a single computation. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( @@ -194,8 +194,8 @@ TEST_F(HloModuleTest, LargeConstantToString) { } TEST_F(HloModuleTest, UniqueModuleId) { - auto module_a = CreateNewModule(); - auto module_b = CreateNewModule(); + auto module_a = CreateNewUnverifiedModule(); + auto module_b = CreateNewUnverifiedModule(); EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e6bfb8025d4bfeba1d334d1f946e33841a2da092..70c7d70b41c5c7bc94d1fac83c0fcf71f155b5f0 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -83,6 +83,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGetDimensionSize, "get-dimension-size") \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 23d41d91d6969ddf9062507e926ae39c1e1315d4..f5f99bece18cc637365118ddcd1273da05f4e1b6 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc 
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -334,7 +334,7 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. for (auto* computation : module->MakeNonfusionComputations()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, HloReachabilityMap::Build(computation)); } } @@ -374,11 +374,10 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( return order_position_.at(a) < order_position_.at(b); } -const std::vector* -SequentialHloOrdering::SequentialOrder( +const HloInstructionSequence* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { return schedule_.is_computation_scheduled(&computation) - ? &schedule_.sequence(&computation).instructions() + ? &schedule_.sequence(&computation) : nullptr; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 66313492eb2dd10ac9a6000639ddb8991b367c0f..a07214c22c0989a438f12219e136a7e76ee0dcce 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +65,7 @@ class HloOrdering { // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. 
- virtual const std::vector* SequentialOrder( + virtual const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const = 0; // Return the call graph of the module used to compute ordering. @@ -96,7 +97,7 @@ class PredecessorHloOrdering : public HloOrdering { // Returns nullptr indicating the computation does not have a sequential // ordering. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return nullptr; } @@ -185,7 +186,7 @@ class SequentialHloOrdering : public HloOrdering { ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override; string ToString() const override; diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index b045adc9640ac0ca8cf4a127fea2fbfcbb1aaf3f..2ab8aa57f6ed4586c3376ee7c44126c0ed19ea0b 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -53,7 +53,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // %c = Constant(42.0f) // // This results in a diamond-shaped callgraph. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder_c = HloComputation::Builder("C"); @@ -126,7 +126,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { // %constant = Constant(1.0) // return While(%constant, body, condition) // - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -176,7 +176,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { // Entry parameter should always be defined before other instruction. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -209,7 +209,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // %while = While(%constant, body, condition) // %add = Add(%constant, %while) // - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -407,7 +407,7 @@ TEST_F(HloOrderingTest, // %dead = Constant(123.0) // // %root should interfere with %dead. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -455,7 +455,7 @@ TEST_F(HloOrderingTest, // ROOT %call = call({%c}), subcomputation // // %root should interfere with %dead. 
- auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto subbuilder = HloComputation::Builder(TestName() + ".sub"); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 450660b94b783bd391a300859c12a3e515f518b9..4390145c6bd7484987b2851ef92336defffb388b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -108,7 +108,7 @@ class HloParser { bool ParseInstructionList(HloComputation** computation, const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); @@ -608,10 +608,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } - return ParseInstruciontRhs(builder, name, name_loc); + return ParseInstructionRhs(builder, name, name_loc); } -bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, +bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc) { Shape shape; HloOpcode opcode; @@ -1547,6 +1547,18 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); + case HloOpcode::kGetDimensionSize: + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateGetDimensionSize( + shape, 
operands[0], (*dimensions)[0])); + break; } instruction->SetAndSanitizeName(name); @@ -2708,7 +2720,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - std::vector split1 = absl::StrSplit(str, "_"); + std::vector split1 = absl::StrSplit(str, '_'); if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; @@ -3389,7 +3401,7 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { // e.g. // // f32[10] fusion(...), calls={...} - if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) { return false; } } else { diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 81eeb9f13bf7f06123c0b35e9f3352c197866a7a..d830fa61438239005875f785f85cf2486123ebc9 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,7 +44,9 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, +// e.g., "{replicated}". StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). @@ -55,10 +57,6 @@ StatusOr ParseWindow(absl::string_view str); StatusOr ParseConvolutionDimensionNumbers( absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string. 
-StatusOr ParseSharding(absl::string_view str); - // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index eae6d19792f72b61bdc3d5374fc43f86baaed532..c59bdc0a0b372d829ee61f0a048b7704498e0d0e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1150,6 +1150,25 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } +)" +}, +// cross-replica-sum with all-reduce-id +{ +"CrossReplicaSumAllReduce", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add +} + )" }, // all-to-all diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc index ee8cb12b231718e09f6ac0d05d7a6887f4c4d746..20384b9da6be4bab447b474f0e2240bcb277a620 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloPassPipelineTest : public HloVerifiedTestBase { +class HloPassPipelineTest : public HloTestBase { protected: StatusOr ParseModuleGroup( absl::Span hlo_strings) { diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2d5197be9e6f69f698729e06b7506a5bc6260bcd..f968a4a94453f678f5c17e0b8d1df4aea70c93ea 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -104,5 +104,20 @@ bool IsScalarConstant(const HloInstruction* instruction) { return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape()); } +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes) { + for (const auto* instr : comp->instructions()) { + if (opcodes.count(instr->opcode())) { + return true; + } + for (const HloComputation* subcomp : instr->called_computations()) { + if (ContainsInstrWithOpcode(subcomp, opcodes)) { + return true; + } + } + } + return false; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c0826a6aee1f693484207a86ec258c6604d92318..215051f8834fc94eb9e32b508f34b13626ac9349 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -41,6 +43,12 @@ bool AllOperandsAreConstants(const HloInstruction& instruction); // Returns whether the instruction is a scalar constant. bool IsScalarConstant(const HloInstruction* instruction); +// Determines whether the given computation contains an instruction with one of +// the given opcodes. Checks both comp's instructions and the instructions of +// any computations nested within it. +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes); + // Returns an operand of an instruction with the given opcode. If there are // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 961930f0a888e90f86e4354fa1373a303af8ec2f..4aa8067752481ffab29e1a573ffa49d4aa046f1f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { @@ -22,7 +24,7 @@ HloReachabilityMap::HloReachabilityMap( : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { - indices_[hlo] = bit_vectors_.size(); + indices_[GetKey(hlo)] = bit_vectors_.size(); bit_vectors_.emplace_back(size_); } CHECK_EQ(size_, indices_.size()); // instructions should be unique @@ -71,4 +73,70 @@ bool HloReachabilityMap::IsConnected(const HloInstruction* a, return IsReachable(a, b) || IsReachable(b, a); } +std::unique_ptr HloReachabilityMap::Build( + const HloComputation* computation) { + const auto& all = computation->MakeInstructionPostOrder(); + auto result = absl::make_unique(all); + auto channel_dependency_map = computation->ComputeChannelDependencies(); + + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + + result->FastSetReachabilityToUnion(inputs, hlo); + } + return result; +} + +void HloReachabilityMap::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction) { + std::queue worklist; + worklist.push(instruction); + + std::vector inputs; + + while (!worklist.empty()) { + 
const HloInstruction* item = worklist.front(); + worklist.pop(); + + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), item->control_predecessors().begin(), + item->control_predecessors().end()); + + if (SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 3e27d098aeb340aae471297e36a7e2824bcd0994..7823b06a41b3052f6f50f7ffa358de5b23ba679f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -16,27 +16,30 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#include #include #include +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace xla { -class HloInstruction; - // A class for representing reachability between HloInstructions. // -// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix -// and it is up to the user of the class to set the adjacency matrix such that -// it represents reachability, i.e. such that it is transitive. 
That the graph -// be transitive is thus not an invariant of this class, but it is required for -// the name of the class and its methods to make sense. +// It has an adjacency matrix and it is up to the user of the class to set the +// adjacency matrix such that it represents reachability, i.e. such that it is +// transitive. That the graph be transitive is thus not an invariant of this +// class, but it is required for the name of the class and its methods to make +// sense. class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given @@ -44,6 +47,15 @@ class HloReachabilityMap { explicit HloReachabilityMap( absl::Span instructions); + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + static std::unique_ptr Build( + const HloComputation* computation); + // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true @@ -70,6 +82,10 @@ class HloReachabilityMap { // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. 
+ void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); + // Returns true if "b" is reachable from "a" // // Note that this function only correctly answers queries about reachability @@ -83,7 +99,9 @@ class HloReachabilityMap { bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; // Checks if an instruction is in the Reachability map. - bool IsPresent(const HloInstruction* a) const { return indices_.contains(a); } + bool IsPresent(const HloInstruction* a) const { + return indices_.contains(GetKey(a)); + } private: // A bit-vector implementation specialized for this use case which provides a @@ -146,18 +164,24 @@ class HloReachabilityMap { absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); + uint64 GetKey(const HloInstruction* instruction) const { + uint64 unique_id = absl::bit_cast(instruction->unique_id()); + uint64 module_id = + absl::bit_cast(instruction->parent()->parent()->unique_id()); + return (module_id << 32) | unique_id; + } // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { - return FindOrDie(indices_, instruction); + return FindOrDie(indices_, GetKey(instruction)); } // The number of instructions in the reachability map. const size_t size_; - // Dense assignment from HloInstruction* to number. These numbers index - // into the bit_vectors_ vector and into the bits within a BitVector. - absl::flat_hash_map indices_; + // Dense assignment from HloInstruction::unique_id to number. These numbers + // index into the bit_vectors_ vector and into the bits within a BitVector. + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. 
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index d9848cee0bfa904a90aea4626c3ee62c2cbb45b6..595176709806d54fc7c7c5ea301654717096b2d6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloVerifiedTestBase {}; +class HloReachabilityTest : public HloTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: @@ -48,7 +48,8 @@ TEST_F(HloReachabilityTest, Reachability) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.Build(); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); HloReachabilityMap reachability({a, b, c, d, e}); reachability.SetReachable(a, a); @@ -81,6 +82,130 @@ TEST_F(HloReachabilityTest, Reachability) { EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); } +TEST_F(HloReachabilityTest, NonTrivialReachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. 
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp)); + + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = HloReachabilityMap::Build(computation); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + 
EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + reachability->UpdateReachabilityThroughInstruction(exp); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + reachability->UpdateReachabilityThroughInstruction(negate); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + 
EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + +TEST_F(HloReachabilityTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = HloReachabilityMap::Build(computation); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f7e82fb1f88e856305f6f481a451d4cd64ba4acf..22c3c40a93a1ddcd36659483fcc79fede32dd2c3 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloVerifiedTestBase { +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -162,7 +162,7 @@ class HloRematerializationTest : public HloVerifiedTestBase { // Test rematerialization of a single computation produced by // MakeRematerializableComputation. TEST_F(HloRematerializationTest, SingleComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, module)); + /*memory_limit_bytes=*/14 * 1024, module.get())); EXPECT_TRUE(changed); // Root should not have changed. @@ -203,7 +203,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. 
TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, module)); + /*memory_limit_bytes=*/20 * 1024, module.get())); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -225,7 +225,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. TEST_F(HloRematerializationTest, RematerializeAroundWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // bit lower (17KB) to force rematerialization of the entry computation. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, module)); + /*memory_limit_bytes=*/17 * 1024, module.get())); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -261,7 +261,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while. Both the entry computation and while body computation should have // computations rematerialized. 
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, module)); + /*memory_limit_bytes=*/15 * 1024, module.get())); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -293,7 +293,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. TEST_F(HloRematerializationTest, RematerializeNestedComputations) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // ~12K so pick something slightly larger. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, module)); + /*memory_limit_bytes=*/13 * 1024, module.get())); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. 
@@ -346,7 +346,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + // // rng + tanh + exp - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( @@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get())); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -420,7 +420,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // The value %bcast is live across each call of Subcomputation (which requires // 8KB) though the value is not used in the calls. Rematerializing %bcast // across these calls reduces peak memory use from ~20KB down to ~16KB. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -533,7 +533,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // (ie %bcast is used indirectly by %negate), otherwise the %negate operand // aliases %add_2. const bool indirectly_used = GetParam(); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // rematerialization). 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 45c684d66752862eec301b8943d350804f070309..11994d99c93e9d51691e482a3e3233b06fb0d060 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -66,7 +66,7 @@ class HloSubcomputationUnificationTest : public HloTestBase { }; TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -103,7 +103,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { } TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -143,7 +143,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { // Do not unify subcomputations with different parameter shapes. TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -184,7 +184,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { // Regression test for b/31466798. Checks that entry_computation is still valid // after unification. 
TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); for (int i = 0; i < 2; ++i) { HloComputation::Builder builder("pow"); auto x = diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f..1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloVerifiedTestBase { +class HloTfGraphBuilderTest : public HloTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 136824a33565d65663f1e484713c5180a762b25b..27fd685a69a0bbd95b1d8d266ce6177a6c557f55 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -755,6 +756,12 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); } +Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { + return CheckShape( + get_size, ShapeInference::InferGetDimensionSizeShape( + get_size->operand(0)->shape(), get_size->dimensions(0))); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -1331,6 +1338,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleCrossReplicaSum(HloInstruction* crs) override { + if (crs->all_reduce_id().has_value()) { + TF_RET_CHECK(crs->all_reduce_id().value() > 0) + << "All reduce id must be greater than 0 for " + << crs->ToShortString(); + } + return Status::OK(); + } + Status Preprocess(HloInstruction* instruction) override { auto previous = instructions_by_name_.find(instruction->name()); TF_RET_CHECK(previous == instructions_by_name_.end()) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 83b6244d1be0e1eec66daabfcfd1be5a3c0131ac..9fbfd6a21c1f1148801000169046fbcbb37934fe 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -94,6 +94,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleGather(HloInstruction* gather) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* 
token) override; + Status HandleGetDimensionSize(HloInstruction* get_size) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index afe01e5487c3225815e01343d86e9fe894c2cde8..5ddfe0a944f04f070f9bdb81697425ee417ac15a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -35,7 +35,7 @@ namespace { using ::testing::HasSubstr; -// This class cannot be converted to use HloVerifiedTestBase. It explicitly +// This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. class HloVerifierTest : public HloTestBase { public: @@ -66,7 +66,7 @@ TEST_F(HloVerifierTest, NullInstructionParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -85,7 +85,7 @@ TEST_F(HloVerifierTest, NullComputationParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -104,7 +104,7 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); 
module->AddEntryComputation(builder.Build()); HloComputation::Builder emb_builder(TestName()); @@ -138,7 +138,7 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { builder.AddInstruction( HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Run the verifier twice. It should fail both times, because it shouldn't @@ -303,7 +303,7 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto status = verifier().Run(module.get()).status(); @@ -327,7 +327,7 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index e103222b55faccf2d0286dce33c0f1ce5df01feb..90904ac00110457bcc3b8974816a7080c4ab89fc 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -90,20 +90,29 @@ string HumanReadableProfileBuilder::ToString() const { op.optimal_seconds < 0 ? "" : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), - op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), - op.transcendental_count <= 0 - ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + op.flop_count > 0 && nsecs > 0 + ? 
HumanReadableNumFlops(op.flop_count, nsecs) + : "", + op.transcendental_count > 0 && nsecs > 0 + ? HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + : "", bytes_per_sec, bytes_per_cycle, op.name); }; - float optimal_seconds_sum = 0.0; + double optimal_seconds_sum = 0; int64 total_flops = 0.; int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { if (op.optimal_seconds > 0) { - optimal_seconds_sum += op.optimal_seconds; + // An op can run faster than the estimated optimum. For example, we might + // estimate a fusion's speed by looking at the size of its operands and + // result, but perhaps the fusion doesn't read the entirety of all of its + // inputs. For the purposes of summing the instructions' optimal speeds, + // we treat the "optimum" as the smallest of either the estimated optimum + // and the actual speed. + optimal_seconds_sum += + std::min(double{op.optimal_seconds}, CyclesToSeconds(op.cycles)); } total_flops += std::max(op.flop_count, int64{0}); total_transcendentals += std::max(op.transcendental_count, int64{0}); @@ -114,7 +123,7 @@ string HumanReadableProfileBuilder::ToString() const { print_op({is_entry_computation_ ? "[total] [entry]" : "[total]", "[total]", /*category=*/"", total_cycles_, total_flops, total_transcendentals, - total_bytes, optimal_seconds_sum}, + total_bytes, static_cast(optimal_seconds_sum)}, /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. @@ -155,8 +164,10 @@ string HumanReadableProfileBuilder::ToString() const { entry.text = op.name; entry.short_text = op.short_name; entry.category_text = op.category; - entry.metric = - CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6; + // Ignore ops that run faster than the estimated optimal here, as we do + // when calculating optimal_seconds_sum. 
+ entry.metric = std::max( + 0., CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6); total_discrepancy_in_microseconds += entry.metric; table.AddEntry(std::move(entry)); } diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d5225b8012b68f851b2bfec219d736ba0d..cf6cf897fe11eda01ba6b22119bba34ac2bef8fe 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -18,19 +18,20 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { +class ImplicitBroadcastRemoverTest : public HloTestBase { protected: ImplicitBroadcastRemover remover_; }; TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -41,15 +42,16 @@ TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_FALSE(remover_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(remover_.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = 
HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -60,13 +62,13 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); @@ -76,6 +78,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -86,9 +89,9 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Subtract(op::Parameter(), @@ -98,6 +101,7 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, 
{1, 4, 1}); @@ -108,9 +112,9 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, @@ -120,6 +124,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { } TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); @@ -132,9 +137,9 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), @@ -147,6 +152,7 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { TEST_F(ImplicitBroadcastRemoverTest, TernaryScalarAndDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -159,9 +165,9 @@ TEST_F(ImplicitBroadcastRemoverTest, builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = 
module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 2d03aebc1aca4c55cca588072233b7a18e70a306..20cc18f981574adf1d95c9f1f87c95634238db06 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { namespace { -class IndexedArrayAnalysisTest : public HloVerifiedTestBase { +class IndexedArrayAnalysisTest : public HloTestBase { protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -61,12 +61,12 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { const string& root_expression, bool print_constants) { IndexedArrayAnalysis indexed_tensor_analysis; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN( - IndexedArrayAnalysis::Array* const array_result, - indexed_tensor_analysis.GetArrayFor( - module().entry_computation()->root_instruction())); + TF_ASSERT_OK_AND_ASSIGN(IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + m->entry_computation()->root_instruction())); string string_result = CanonicalizeWhitespace( 
indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index a85793e477459f33f9ca91d168e45989d3b869a6..7f2d7e7cffc6debaaf9b64fffc5a8a7037ecdaa3 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -155,6 +155,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGetDimensionSize: return true; } @@ -452,7 +453,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); @@ -566,7 +567,7 @@ bool InstructionFusion::MultiOutputFusionCreatesCycle( // A consumer operand may have been multii-output fused into a parallel // consumer and thus be missing from the oridinal reachability map. if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { - reachability_ = consumer->parent()->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(consumer->parent()); } return reachability_->IsReachable(a, b); }; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 4045e886dd9ae70166281048483b0f73a35db094..198bd7fce5f392e5e895b959523d4fe9cf208ba2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/platform/macros.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index da1ad90959dc0ab1a840b3390281ce9d4999651e..39904bd54b09a916d3e26e90c62cd6a202f9588d 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -133,7 +133,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -149,7 +149,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( @@ -172,7 +172,7 @@ TEST_F(InstructionFusionTest, 
AvoidDuplicationIfNotAllFusible) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_FALSE( @@ -361,7 +361,7 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary2, computation->root_instruction()); EXPECT_TRUE( @@ -385,7 +385,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_TRUE( diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index c9b40d3c6195f80a19272a0d98890049d02315b9..b0fc1af8b89d7327a00f77f471e90d143a92de7c 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -110,3 +110,5 @@ REGISTER_MODULE_INITIALIZER( // open-source project, so this will be a no-op there. 
REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, multi_platform_manager); +REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, + interpreter_platform); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 6b03394669858ef0ffdbdd1a2bad90e9df9fbcd9..a90411922205c0006159ff99f35a70138b1bee4f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2092,6 +2092,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kGetDimensionSize: return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 47bfca2fd6e1527b73e396151d3764867ac03697..2400b7bb7c409a4dcb33e6e8f4b409738510f3d6 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,20 +49,19 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloVerifiedTestBase { +class LayoutAssignmentTest : public HloTestBase { protected: - void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout, + void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, /*channel_constraints=*/channel_constraints); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m).status()); } - std::vector LayoutOf(HloModule* module, absl::string_view name) { + std::vector LayoutOf(HloModule* m, absl::string_view name) { auto minor_to_major = - FindInstruction(module, name)->shape().layout().minor_to_major(); + FindInstruction(m, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } @@ -91,7 +90,7 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { TEST_F(LayoutAssignmentTest, ComputationLayout) { // Verify the layouts of the root and parameter instructions of a computation // match the ComputationLayout for two different layouts. 
- std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); @@ -101,8 +100,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout layout = LayoutUtil::MakeLayout(minor_to_major); Shape shape(ashape); @@ -113,7 +112,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -131,8 +130,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); Shape col_major_shape(ashape); @@ -149,7 +148,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { 
*computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -160,7 +159,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { // Verify that the layout of the fused parameters in a fusion instruction // match that of the fusion operands. Other fused instructions should have no // layout. - std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); auto constant_literal1 = LiteralUtil::CreateR2WithLayout( @@ -180,8 +179,8 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); @@ -194,7 +193,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -229,13 +228,13 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { auto negate = builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kNegate, get_element0)); - auto module = 
CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -267,17 +266,17 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -302,11 +301,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { auto nested_tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, inner_tuple})); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = nested_tuple->shape(); *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); @@ -316,7 
+315,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -332,9 +331,9 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module) + .Run(m.get()) .ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), @@ -361,9 +360,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -374,7 +372,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -403,8 +401,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { HloInstruction::CreateTranspose(bshape, log, {1, 0})); auto tanh = builder.AddInstruction( 
HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + auto computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -415,7 +413,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -439,9 +437,9 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { HloInstruction::CreateBroadcast(bshape, param, {1, 2})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); Shape input_shape_with_layout(ashape); Shape output_shape_with_layout(cshape); @@ -454,7 +452,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -488,9 +486,8 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); - auto module = CreateNewModule(); - HloComputation* computation = - 
module->AddEntryComputation(builder.Build(tuple)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tuple)); ComputationLayout computation_layout(computation->ComputeProgramShape()); Shape param_shape_with_layout(f32_4); @@ -507,7 +504,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -558,9 +555,8 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); auto reshape = builder.AddInstruction( HloInstruction::CreateReshape(cshape, concatenate)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(reshape)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(reshape)); Shape param0_shape_with_layout(ashape); Shape param1_shape_with_layout(ashape); @@ -573,7 +569,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -593,11 +589,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); auto transpose = 
builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -611,11 +607,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -681,12 +677,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - ParseAndVerifyModule(module_str); - + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); @@ -721,9 +717,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - ParseAndVerifyModule(module_str); + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -735,19 +732,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -757,7 +754,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); @@ -784,7 +781,7 @@ TEST_F(LayoutAssignmentTest, 
ConditionalAsymmetricLayout) { true_builder.AddInstruction(HloInstruction::CreateTuple({add})); } HloComputation* true_computation = - module->AddEmbeddedComputation(true_builder.Build()); + m->AddEmbeddedComputation(true_builder.Build()); auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); { @@ -800,14 +797,14 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = - module->AddEmbeddedComputation(false_builder.Build()); + m->AddEmbeddedComputation(false_builder.Build()); builder.AddInstruction(HloInstruction::CreateConditional( result_tshape, pred, tuple, true_computation, tuple, false_computation)); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -828,13 +825,13 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module).status(); + Status error_status = layout_assignment.Run(m.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( 
error_status.error_message(), @@ -861,9 +858,10 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -873,12 +871,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(&module(), &computation_layout, &channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } @@ -897,17 +895,17 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 ar.0 = f32[2,2] cross-replica-sum(gte), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, ParseAndReturnVerifiedModule(module_str)); 
ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -917,12 +915,12 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); - const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); } @@ -938,11 +936,12 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -966,11 +965,12 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), 
backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -997,11 +997,12 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1028,11 +1029,12 @@ TEST_F(LayoutAssignmentTest, } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1050,11 +1052,12 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = @@ -1107,20 +1110,21 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); // Sanity check to verify that there's a layout 
mismatch. - EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Make sure that layout assignment did not magically eliminate the mismatch, // in which case the test didn't prove anything. - EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); } TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { @@ -1136,32 +1140,32 @@ ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { // and result layout should match that of the computation. { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ASSERT_THAT(root, op::CustomCall(op::Parameter())); ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); } { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, 
ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ASSERT_THAT(root, op::CustomCall(op::Parameter())); ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); @@ -1179,24 +1183,24 @@ ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3 } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_parameter_layout(1) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // The custom call should be partially encapsulated in kCopy instructions // because of the layout mismatches. 
- ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall(op::Copy(), op::Parameter()))); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); @@ -1211,18 +1215,18 @@ ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall())); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); } @@ -1238,25 +1242,25 @@ ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f } )"; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); 
*computation_layout.mutable_parameter_layout(1) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout( ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); - ASSERT_THAT(module->entry_computation()->root_instruction(), + ASSERT_THAT(m->entry_computation()->root_instruction(), op::Copy(op::CustomCall(op::Tuple()))); const HloInstruction* custom_call = - module->entry_computation()->root_instruction()->operand(0); + m->entry_computation()->root_instruction()->operand(0); ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); } @@ -1273,36 +1277,34 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, // Try with a couple different layouts. In each case the custom calls operand // and result layout should match that of the computation. 
TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, + std::unique_ptr m, ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); - ComputationLayout computation_layout = module->entry_computation_layout(); + ComputationLayout computation_layout = m->entry_computation_layout(); *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - ExpectTupleLayoutIs(module->result_shape(), {{1, 0}, {1, 0}}); + ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}}); - const HloInstruction* custom_call = - FindInstruction(module.get(), "custom-call"); + const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call"); ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); } Status AssignLayoutsToComputation( - HloModule* module, - ChannelLayoutConstraints* channel_constraints = nullptr) { - if (!module->entry_computation_layout().result_layout().LayoutIsSet()) { - module->mutable_entry_computation_layout() + HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) { + if (!m->entry_computation_layout().result_layout().LayoutIsSet()) { + m->mutable_entry_computation_layout() ->mutable_result_layout() ->SetToDefaultLayout(); } LayoutAssignment layout_assignment( - module->mutable_entry_computation_layout(), + m->mutable_entry_computation_layout(), LayoutAssignment::InstructionCanChangeLayout, channel_constraints); - return layout_assignment.Run(module).status(); + return layout_assignment.Run(m).status(); } TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { @@ -1325,16 +1327,16 @@ TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { auto add = b.AddInstruction( 
HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1)); b.AddInstruction(HloInstruction::CreateTuple({add, transpose})); - auto module = CreateNewVerifiedModule(); - module->AddEntryComputation(b.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(b.Build()); Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0}); Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1}); - *module->mutable_entry_computation_layout()->mutable_result_layout() = + *m->mutable_entry_computation_layout()->mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor})); const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0}); - ForceParameterLayout(module.get(), 0, r2_dim0major); - ForceParameterLayout(module.get(), 1, r2_dim0major); - TF_ASSERT_OK(AssignLayoutsToComputation(module.get())); + ForceParameterLayout(m.get(), 0, r2_dim0major); + ForceParameterLayout(m.get(), 1, r2_dim0major); + TF_ASSERT_OK(AssignLayoutsToComputation(m.get())); EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0)); EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(), diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 850501a4b5c521f5a5cc29658a04ae4bb638e14f..728a66b388f0f9af480ff88b5e96990a26e36af5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", ], @@ -197,14 +198,17 @@ cc_library( hdrs = ["sort_util.h"], deps = [ ":ir_array", + ":kernel_support_library", ":llvm_loop", ":llvm_util", ":loop_emitter", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 2e5aebb74c29b809ae5c323b1912043d9f160d67..df78726166eea953b57e72a5a5fc81ee246aca34 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" @@ -83,10 +84,9 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b) { +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); @@ -260,6 +260,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, /*AddNull=*/false); } +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name) { + const int kNVPTXSharedMemoryAddrSpace = 3; + return new llvm::GlobalVariable( + *module, tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); +} + llvm::AllocaInst* 
EmitAllocaAtFunctionEntry(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index f59baff263fe7184c6b0821c9dbd9eee205586a6..c604c7c870adf734a29017e6accbd159317a9548 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -101,10 +102,9 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b); +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise @@ -155,6 +155,11 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module); +// Allocates a tile of shared memory. +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name); + // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. 
The insert point // of the builder is set to the same place after calling this function diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 05ba4a40da413f0e774214e55ef69d023afc48e2..fd16af67fe99b4f440ad962b4b648a3b22c41dc6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -18,7 +18,9 @@ limitations under the License. #include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -28,10 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -39,147 +43,352 @@ namespace xla { namespace llvm_ir { namespace { -// Adds the inner comparison loop where we compare elements pointed to by -// 'keys_index' and 'compare_keys_index'. 
-void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, - const IrArray::Index& compare_keys_index, - const IrArray& keys_array, - const std::vector& values_arrays, - llvm::IRBuilder<>* b) { - // if (is_smaller_index && - // compare_keys[dimension_to_sort] < dimension_to_sort_bound) - llvm::Value* is_smaller_index = b->CreateICmpSLT( - keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - auto if_data = EmitIfThenElse( - b->CreateAnd(is_smaller_index, - b->CreateICmpSLT(compare_keys_index[dimension_to_sort], - keys_index.GetConstantWithIndexType( - dimension_to_sort_bound))), - "smaller_comparison_index", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_data.true_block, b); - auto key1 = keys_array.EmitReadArrayElement(keys_index, b); - auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); - auto compare_key1 = key1; - auto compare_key2 = key2; - auto key_type = keys_array.GetShape().element_type(); - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - // We would like a total order of floating point numbers so that the sort - // has a predictable behavior in the presence of NaNs. Rather than using - // floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the obvious - // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - // and end of the ordering. 
- auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; + +// Adds the inner comparison loop body where we compare elements. +void EmitCompareLoopBody( + int64 iteration_bound, PrimitiveType key_type, int64 num_values, + llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type, + std::function read_element, + std::function + write_element, + llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { + auto index_typed_constant = [&](int64 value) { + return llvm::ConstantInt::get(index_type, value); + }; + // The 'xor_mask' determines which elements are compared against each other. + // Index 'current_keys_index' will be compared with 'current_keys_index' xor + // 'xor_mask'. This means that we will always compare a block of consecutive + // elements against elements from the adjacent block of the same size. When + // 'xor_mask' is a power of 2, it immediately identifies the size of such a + // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In + // that case, we essentially flip the last 'k' - 1 bits when computing the + // position of the element to compare to, so the block size is 2^(k - 1). + int64 block_size = xor_mask; + // Check if it is a value 2^k - 1. 
+ if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) { + block_size = (xor_mask + 1) / 2; + } + auto current_keys_index = element_pair_index; + if (block_size == 1) { + // If the block size is 1, we take every second element and compare it to + // the next one. + current_keys_index = + b->CreateMul(current_keys_index, index_typed_constant(2)); + } else if (block_size * 2 < iteration_bound) { + // current_keys_index iterates through the 'left' elements of the element + // pairs to be compared. We first need to compute the comparison block to + // which the element belongs. The block id of that block is index / + // block_size. + auto block_id = + b->CreateUDiv(current_keys_index, index_typed_constant(block_size)); + // The index of the 'left' element within its block is simply the remainder + // when dividing by 'block_size'. + auto index_within_block = + b->CreateURem(current_keys_index, index_typed_constant(block_size)); + // The first element of the 'left' block of elements that is compared + // against elements from the adjacent 'right' block of elements is + // 'block_id' * (2 * 'block_size'). + auto first_element_in_block = + b->CreateMul(block_id, index_typed_constant(2 * block_size)); + current_keys_index = + b->CreateAdd(first_element_in_block, index_within_block); + } + auto compare_keys_index = + b->CreateXor(current_keys_index, index_typed_constant(xor_mask)); + // current_keys_index < compare_keys_index + llvm::Value* is_smaller_index = + b->CreateICmpSLT(current_keys_index, compare_keys_index); + // compare_keys_index < iteration_bound + llvm::Value* index_is_inbounds = b->CreateICmpSLT( + compare_keys_index, index_typed_constant(iteration_bound)); + llvm::Value* do_comparison = + needs_bounds_checks ? 
b->CreateAnd(is_smaller_index, index_is_inbounds) + : b->getInt1(true); + + // if (is_smaller_index && index_is_inbounds) + KernelSupportLibrary ksl(b); + ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + auto key1 = read_element(0, current_keys_index); + auto key2 = read_element(0, compare_keys_index); + auto compare_key1 = key1; + auto compare_key2 = key2; + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the + // sort has a predictable behavior in the presence of NaNs. Rather + // than using floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the + // obvious order, -0 is ordered before 0, and -NaN and NaN appear at + // the beginning and end of the ordering. + auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } + // If key2 < key1 + ksl.IfReturnVoid( + "is_smaller_than", + b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1), + [&]() { + // Swap key1 with key2. + write_element(0, current_keys_index, key2); + write_element(0, compare_keys_index, key1); + for (int64 i = 1; i <= num_values; ++i) { + // Also swap the values. 
+ auto value1 = read_element(i, current_keys_index); + auto value2 = read_element(i, compare_keys_index); + write_element(i, current_keys_index, value2); + write_element(i, compare_keys_index, value1); + } + }); + }); +} + +void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index, + int64 dimension_to_sort, + int64 dimension_to_sort_bound, + PrimitiveType keys_type, + absl::Span xor_masks, + const std::vector& params, + const std::vector& param_shmem_buffers, + int64 tile_size, llvm::IRBuilder<>* b) { + KernelSupportLibrary ksl(b); + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); + llvm_ir::AddRangeMetadata(0, tile_size / 2, + llvm::cast(thread_id)); + thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), + /*isSigned=*/true, "thread.id.x"); + + auto copy_loop_body = + [&](std::function + read_or_write) { + auto value_one = tiled_keys_index.GetConstantWithIndexType(1); + auto current_keys_index = + b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); + // We want to copy two adjacent elements. We first check whether the + // first index position is within bounds. + ksl.IfReturnVoid( + "smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + auto cache_index = b->CreateShl(thread_id, value_one); + read_or_write(cache_index, current_keys_index); + // Increment to go the next index position. + current_keys_index = b->CreateAdd(current_keys_index, value_one); + // Here we check whether the next index position is within bounds. + ksl.IfReturnVoid( + "inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); + }); + }; + + // Copy operand tiles from the operand buffers to shared memory. 
+ IrArray::Index keys_index = tiled_keys_index; + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = params[i].EmitReadArrayElement(keys_index, b); + b->CreateStore(value, + b->CreateGEP(param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), + cache_index})); + }); + } + // Wait until all reads have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + + // Now emit the bodies of the comparison loops. + auto read_element = [&](int64 operand, llvm::Value* index) { + return b->CreateLoad( + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + b->CreateStore( + value, + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + for (int64 xor_mask : xor_masks) { + // The index of the element pair to be compared within the tile stored in + // shared memory. We order the element pairs by the element with the smaller + // index. + auto element_pair_index = thread_id; + // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we don't + // need any bounds checks. + if (dimension_to_sort_bound % tile_size) { + // Otherwise we need a bounds check for the last tile. The last tile has + // size 'dimension_to_sort_bound' % 'tile_size'. 
+ ksl.IfReturnVoid( + "is_last_tile", + b->CreateICmpUGE( + b->CreateMul(tiled_keys_index[dimension_to_sort], + tiled_keys_index.GetConstantWithIndexType(2)), + tiled_keys_index.GetConstantWithIndexType( + RoundDownToNearest(dimension_to_sort_bound, tile_size))), + [&]() { + EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, + params.size() - 1, element_pair_index, xor_mask, + tiled_keys_index.GetType(), read_element, + write_element, b); + }, + [&]() { + EmitCompareLoopBody( + tile_size, keys_type, params.size() - 1, element_pair_index, + xor_mask, tiled_keys_index.GetType(), read_element, + write_element, b, /*needs_bounds_checks=*/false); + }); + } else { + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + element_pair_index, xor_mask, + tiled_keys_index.GetType(), read_element, + write_element, b, /*needs_bounds_checks=*/false); + } + // Wait until all comparisons have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); } - auto comparison = - b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - // If key2 < key1 - auto if_smaller_data = - EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_smaller_data.true_block, b); - // Swap key1 with key2. - keys_array.EmitWriteArrayElement(keys_index, key2, b); - keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); - for (const auto& values_array : values_arrays) { - // Also swap the values. - auto value1 = values_array.EmitReadArrayElement(keys_index, b); - auto value2 = values_array.EmitReadArrayElement(compare_keys_index, b); - values_array.EmitWriteArrayElement(keys_index, value2, b); - values_array.EmitWriteArrayElement(compare_keys_index, value1, b); + + // Copy the operand tiles back from shared memory to the operand buffers. 
+ for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = b->CreateLoad(b->CreateGEP( + param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), cache_index})); + params[i].EmitWriteArrayElement(keys_index, value, b); + }); } + // We should normally synchronize here to make sure all writes have happened. + // However the very next thing each thread does is reading 2 elements from the + // operand buffer and writing it into the same location in shared memory from + // which it previously copied it to the operand buffer, and we synchronize + // after this has happened. We can be sure that a thread always writes to the + // same location in shared memory because we have exactly tile_size / 2 many + // threads, and the linear index calculated by ParallelLoopEmitter uses + // linear_index = blockIdx.x * blockDim.x + threadIdx.x; } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions) { - const Shape& keys_shape = keys_array.GetShape(); + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, + const int64 tile_size) { + // Iterate through the keys shape in physical order, but skip the dimension to + // sort and make it the innermost loop which is the loop where the comparisons + // happen. 
In the dimension to sort, if we use tiling, we iterate through it + // in tiles of 64 elements each, so we use another loop that happens within + // one thread to process this tile worth of data (thereby combining several + // comparison stages of the bitonic sort algorithm because they all happen + // within those 64 elements and are therefore independent of the other + // comparisons). - // Create loop nests which loop through the operand dimensions. The sort - // dimension is handled in the innermost loop which performs the sorting. - ForLoopNest loop_nest(name, b); - IrArray::Index keys_index = - loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys"); - if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) { - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b); + const Shape& keys_shape = keys_array.GetShape(); + int64 rank = ShapeUtil::Rank(keys_shape); + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + std::vector dimensions_in_iteration_order(rank); + std::vector iteration_order_to_logical_order(rank); + int64 dim = 0; + for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) { + if (dimension != dimension_to_sort) { + dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension); + iteration_order_to_logical_order[dim++] = dimension; + } } + dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim; + iteration_order_to_logical_order[dim] = dimension_to_sort; - // 'compare_keys_index' is the index of the element that 'keys_index' should - // be compared to. 
- IrArray::Index compare_keys_index(keys_index.GetType()); - for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { - if (dimension != dimension_to_sort) { - compare_keys_index.push_back(keys_index[dimension]); - } else { - compare_keys_index.push_back(nullptr); + Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), + dimensions_in_iteration_order); + std::vector params(1, keys_array); + params.insert(params.end(), values_arrays.begin(), values_arrays.end()); + + // Allocate shared memory for the tiled compare loop. + std::vector param_shmem_buffers(params.size(), nullptr); + if (xor_masks.size() > 1) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + for (int64 i = 0; i < params.size(); ++i) { + llvm::Type* tile_type = + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + params[i].GetShape().element_type(), module), + tile_size); + param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( + module, tile_type, absl::StrCat(name, "_tile_param_", i)); } } - // Naive C++ code for the inner compare loop: - // - // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { - // int64 j = i ^ xor_mask; - // if (i < j && j < dimension_to_sort_bound) { - // int64 min_key = std::min(keys[i], keys[j]); - // keys[j] = std::max(keys[i], keys[j]); - // keys[i] = min_key; - // } - // } - // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter - - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(), - {dimension_to_sort_bound}); auto compare_loop_body_emitter = - [&](const IrArray::Index& compare_index) -> Status { - keys_index[dimension_to_sort] = compare_index[0]; - compare_keys_index[dimension_to_sort] = - b->CreateXor(compare_index[0], xor_mask); - EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, values_arrays, b); + [&](const 
IrArray::Index& tiles_index) -> Status { + // Naive C++ code for the inner compare loop: + // + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ xor_mask; + // /* emitted in EmitCompareLoopBody() */ + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + IrArray::Index keys_index(tiles_index.GetType(), rank); + for (int64 i = 0; i < rank; ++i) { + keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; + } + if (xor_masks.size() > 1) { + EmitTiledCompareLoop(keys_index, dimension_to_sort, + dimension_to_sort_bound, keys_shape.element_type(), + xor_masks, params, param_shmem_buffers, tile_size, + b); + } else { + auto read_element = [&](int64 operand, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + return params[operand].EmitReadArrayElement(keys_index, b); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + keys_index[dimension_to_sort] = index; + params[operand].EmitWriteArrayElement(keys_index, value, b); + }; + EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), + values_arrays.size(), tiles_index[rank - 1], + xor_masks[0], tiles_index.GetType(), read_element, + write_element, b); + } return Status::OK(); }; - if (launch_dimensions != nullptr) { - TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter, - compare_shape, - *launch_dimensions, b) - .EmitLoop(name)); - } else { - TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b) - .EmitLoop(name)); - } - - // Set the IR builder insert point to the exit basic block of the outer most - // loop. This ensures later instructions are inserted after this loop nest. 
- b->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); + return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape, + launch_dimensions, b) + .EmitLoop(name); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 2f3bcda2307bcbb35a03b9e71dbbe44e366b3820..556a217322d373ffd5e816dcf35888b546806633 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -29,13 +30,14 @@ namespace xla { namespace llvm_ir { // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' // dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, -// the inner compare loop will not be parallelized. +// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' +// contains only powers of 2, or values 2^k - 1 (k > 0). 
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const std::vector& values_arrays, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions); + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/map_inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc index 84059dd0f71ee8fc0a25703cbab2268d7dc149a8..fd18bfdc3e7f4b5f94237c554c3e6ca8bd065a35 100644 --- a/tensorflow/compiler/xla/service/map_inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using MapInlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(MapInlinerTest, MapMax) { @@ -59,12 +59,12 @@ TEST_F(MapInlinerTest, MapMax) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); 
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -93,12 +93,12 @@ TEST_F(MapInlinerTest, MapConstant) { HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -131,12 +131,12 @@ TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 2ca527bc4cb8f66a085c1e6a7cbb8ddaedbfc07e..9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -257,7 +258,7 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, } void MultiOutputFusion::RecomputeReachability() { - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); } void MultiOutputFusion::UpdateReachability( @@ -317,9 +318,9 @@ bool MultiOutputFusion::Perform() { << instr2->fused_instructions_computation()->ToString( HloPrintOptions().set_indent_amount(1)); } + Update(instr1, instr2); HloInstruction* ret = Fuse(instr1, instr2); set_is_fused(ret == instr1 ? instr2 : instr1); - Update(instr1, instr2); changed = true; VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" << ret->fused_instructions_computation()->ToString( diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 9508ab2ed1d38ec40983d8892ec8875b848fb21b..1c7583ece720f9e4d4b71a6279b976fed40e10cb 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index c522e7ae23b734090f85d241bf365fccc37f0adb..c227106511c2c17b44569d3b696cd7d764226e81 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -59,20 +59,15 @@ string CanonicalPlatformName(const string& name) { /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { - se::MultiPlatformManager::PlatformMap platform_map; - se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms( - [&platform_map](se::MultiPlatformManager::PlatformMap* map) { - platform_map = *map; - return se::port::Status::OK(); - }); - if (platform_map.empty()) { + std::vector all_platforms = + se::MultiPlatformManager::AllPlatforms(); + if (all_platforms.empty()) { LOG(WARNING) << "no executor platforms available: platform map is empty"; } // Gather all platforms which have an XLA compiler. 
std::vector platforms; - for (auto& platform_pair : platform_map) { - auto* platform = platform_pair.second; + for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { platforms.push_back(platform); @@ -222,8 +217,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { // fix the number of devices to one. However we do let the user override // this behavior to help run tests on the host that run models in parallel // across multiple devices. - device_count = legacy_flags::GetDebugOptionsFromFlags() - .xla_force_host_platform_device_count(); + device_count = + GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 688cceff0cd10df62a4093f00ad3331ca77652e0..b70cb7057477a338bfb36ebab76237b30d018e41 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -111,7 +111,7 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( VLOG(2) << "Adding to operand " << i << ": " << operand; if (!is_valid_shape(operand->shape())) { - VLOG(2) << "Skipped: value is not an F32 vector"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } @@ -168,7 +168,7 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( << instruction->ToString(); if (!is_valid_shape(instruction->shape())) { - VLOG(2) << "Skipped: value is not an F32 nonscalar array"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 0b4e82e8d606cf2cacfab42d07c2201939d5e10b..76c6a87f176ec9c6f8e49c25278c6dad703e3c7c 100644 --- 
a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass { // equivalent behavior can be obtained by adding ReducePrecision // instructions after the instructions that pull the F32 arrays out of // the tuples. - // - // TODO(b/64093391): Remove the IsScalar check once this won't cause - // failures on the GPU backend if the ReducePrecision instruction ends up - // inserted between a scalar constant and the init_value argument of a - // Reduce operation. - return shape.element_type() == PrimitiveType::F32 && - !ShapeUtil::IsScalar(shape); + return shape.element_type() == PrimitiveType::F32; } // Is this instruction one such that following or preceding it with a new diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index 69e4b534bd8e3aeab8b729f3e594a10b4368f15f..16fa80d53e7dc3456b0dade8b92cf101b3e0a33d 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -54,7 +54,34 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm expected graph after adding ops. 
+ EXPECT_EQ(computation->root_instruction(), b); + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); +} + +TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + + // Create a simple graph with a parameter feeding a unary cosine function. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -84,7 +111,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -113,7 +140,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -146,7 +173,7 @@ TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { HloInstruction* d = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -178,7 +205,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -215,7 +242,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -242,7 +269,7 @@ TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -268,7 +295,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -294,7 +321,7 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, a, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -321,7 +348,7 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 5, 10)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -349,7 +376,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -375,7 +402,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -411,7 +438,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. 
@@ -458,7 +485,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index fcf269eee925c2ddb7511d70e71bd815e4b8c24a..341659b15c4c7355d39739ee171a4a749d87e929 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,9 +34,10 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ReshapeMoverTest : public HloVerifiedTestBase {}; +class ReshapeMoverTest : public HloTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -50,12 +51,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -74,6 +75,7 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { // Verifies that the reshape is not moved, since rng0 is trivially reshapable // and therefore there is no nontrivial reshapes to move. TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( @@ -92,18 +94,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); } TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -117,12 +120,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), 
op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -130,6 +133,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -143,11 +147,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -177,6 +181,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { // | // reshape4 TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction( @@ -196,12 +201,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); @@ -221,6 +226,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { // Verifies that the reshape0 does not sink below add, because param1 is not // trivially reshapable nor is a Reshape/Transpose. TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -232,11 +238,11 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -257,6 +263,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // Verifies that we don't unnecessarily sink reshapes, which are in fact // trivial reshapes. 
TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -275,12 +282,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -309,6 +316,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { // // (note that reshape1 here is trivial). 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -320,12 +328,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -348,6 +356,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { // For now we treat it as non-trivial, so we verify that we don't sink the // reshapes in this case. 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -362,12 +371,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); @@ -376,6 +385,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -389,14 +399,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, 
param1))); @@ -405,6 +415,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7}); @@ -423,13 +434,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -438,6 +449,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { } TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto pred_shape = ShapeUtil::MakeShape(PRED, {}); @@ -452,11 +464,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); 
@@ -477,6 +489,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { // // We expect reshape{0,1} AND reshape{2,3} to be lifted. TEST_F(ReshapeMoverTest, MultiplePasses) { + auto m = CreateNewVerifiedModule(); auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); @@ -500,14 +513,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -526,11 +539,11 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Transpose(op::Multiply())); } @@ -555,8 +568,8 @@ TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_FALSE(changed); } @@ -580,10 +593,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { } )"; - 
ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); } @@ -597,10 +610,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reshape(op::Add())); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 6f9094a5c2e882f4bc1531efdef654a6afa2ddb6..75f7413b3c303da620c2815c83e03324148c0961 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -23,9 +23,9 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -292,7 +292,7 @@ StatusOr> Service::CreateModuleConfig( config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { - config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config->set_debug_options(GetDebugOptionsFromFlags()); } if (execute_backend_ != nullptr && @@ -760,38 +760,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - ExecuteGraphParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteGraphParallel(&parallel_arg, &parallel_result)); - return PickParallelResponse(parallel_result, result); -} - -Status Service::PickParallelResponse( - const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { - // The "result device" selection is a bit hacky, but better than assuming it - // is device 0. We have b/76035356 for restructuring the client API to clean - // up the current asymmetries and support more functionalities. 
- for (int64 i = 0; i < parallel_result.responses_size(); ++i) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica( - parallel_result.responses(i).output(), 0)); - const Shape& shape = buffer->on_host_shape(); - if (!ShapeUtil::IsEmptyTuple(shape)) { - *result = parallel_result.responses(i); - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return Status::OK(); - } - } - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - VLOG(1) << "Defaulting to device 0 result"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -836,10 +804,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute-graph request"; - +Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { + VLOG(1) << "running compile request"; if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -847,22 +813,21 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return InvalidArgument("programe shape may not be empty"); } - // If we received multiple device handles, we must partition the module. 
if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); + return InvalidArgument( + "The compile request does not support multiple device handles."); } - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - + std::vector argument_shapes; + absl::c_transform(arg->input_shape_with_layout(), + std::back_inserter(argument_shapes), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(arg->computation().host_program_shape(), - replicated_arguments.front(), - arg->execution_options())); + argument_shapes, &arg->execution_options())); + VLOG(3) << "Compile created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, @@ -871,6 +836,48 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); + + VLOG(1) << "successfully completed 'compile' request"; + return Status::OK(); +} + +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { + VLOG(1) << "running execute request"; + if (!arg->has_handle()) { + return InvalidArgument("execution handle should not be empty"); + } + TF_ASSIGN_OR_RETURN(auto executable, + compilation_cache_.LookUp(arg->handle())); + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + // Check that the replicated_arguments has the same shape and layout as the + // module config used when creating the executable. 
+ const int64 num_module_args = + executable->module_config().entry_computation_layout().parameter_count(); + if (num_module_args != arg->arguments_size()) { + return InvalidArgument( + "The executable expects %lld arguments, but sees %lld.", + num_module_args, arg->arguments_size()); + } + for (int64 i = 0; i < num_module_args; i++) { + const Shape& shape_module = + executable->module_config().entry_computation_layout().parameter_shape( + i); + const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape(); + if (!ShapeUtil::Equal(shape_module, shape_arg)) { + return InvalidArgumentStrCat( + "The executable exepcts the ", i, "th argument in shape ", + ShapeUtil::HumanStringWithLayout(shape_module), " but sees ", + ShapeUtil::HumanStringWithLayout(shape_arg)); + } + } + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( execute_backend_->default_stream_executor())); @@ -884,9 +891,10 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + arg->computation().name(), result->mutable_profile())); + ExecuteAndRegisterResult(executable.get(), replicated_arguments, + execute_backend_.get(), + "result of " + executable->module().name(), + result->mutable_profile())); if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN( @@ -898,7 +906,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } - VLOG(1) << "successfully completed 'execute-graph' request"; + VLOG(1) << "successfully completed 'execute' request"; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 8cf1a7b9f01fbb3572c6849c8b18e14174ced89f..11e1a79552fbd944ab28da129b08cfe676fb08e9 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h 
@@ -22,11 +22,12 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" +#include "tensorflow/compiler/xla/service/compilation_cache.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" @@ -90,11 +91,14 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Executes a computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. - // Returns global data output and execution timing. - Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + // Compiles a computation into an executable. The request contains the whole + // computation graph. Returns the handle to the executable. + Status Compile(const CompileRequest* arg, CompileResponse* result) override; + + // Executes an executable with the provided global data passed as immutable + // arguments. The request contains the handle to the executable. Returns + // global data output and execution timing. + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each @@ -179,10 +183,6 @@ class Service : public ServiceInterface { absl::Span arguments, const ExecutionOptions& execution_options); - // Picks a parallel response and fills the result. 
- Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, - ExecuteResponse* result); - // Prepare the executors for executing parallel. StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, @@ -254,11 +254,6 @@ class Service : public ServiceInterface { Backend* backend, absl::Span device_handles, absl::Span result_tags, ExecutionProfile* profile); - // Executes a single computation which has more than one target device. - // The N devices are expected to all return an empty tuple, but one, which - // will be the result of this computation. - Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. @@ -281,6 +276,9 @@ class Service : public ServiceInterface { ServiceOptions options_; + // Cache containing previously built Executables. + CompilationCache compilation_cache_; + // Tracks channels created via the API. 
ChannelTracker channel_tracker_; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 2f8f092303ed1821d9bff021da0e835f1878f5ed..61a60ef9efa72f53fa2c6730ca297ddfe01c56ba 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2031,6 +2031,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } +/* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( + const Shape& shape, int64 dimension) { + if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", + dimension); + } + + return ShapeUtil::MakeShape(S64, {}); +} + /* static */ StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { @@ -2833,6 +2843,15 @@ Status ValidateScatterDimensionNumbers( } } + // Validate window size. + auto window_size = dim_numbers.update_window_dims_size() + + dim_numbers.inserted_window_dims_size(); + if (window_size != ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Scatter op has window of size %d; doesn't match operand of rank %d.", + window_size, ShapeUtil::Rank(operand_shape)); + } + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. 
if (dim_numbers.scatter_dims_to_operand_dims_size() != scatter_indices_shape[dim_numbers.index_vector_dim()]) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index cd4e5ab52ca5e33424f2b78f83cc94961b254493..31ef4b2e41078f87731a1eff58e37409a6004ba4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -291,6 +291,9 @@ class ShapeInference { const Shape& updates_shape, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); + static StatusOr InferGetDimensionSizeShape(const Shape& shape, + int64 dimension); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 7b65e8c1c9d2bc730c6c8550e9265b69fdde71cf..4639e32db4d59080a9e85e46983fac61d9e76be9 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -2673,5 +2673,23 @@ TEST_F(ScatterGatherShapeInferenceTest, << statusor.status(); } +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_InsufficientWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Scatter op has window of size 4; doesn't match operand of rank 5.")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc 
b/tensorflow/compiler/xla/service/shaped_buffer.cc index 56952e3adae59656605a12fd499162504a2a3379..28a30b5ee2dbcb5012804578d4d037c241045309 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -157,4 +157,23 @@ void ScopedShapedBuffer::Deallocate() { } } +ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) { + const xla::Shape& sub_on_host_shape = + xla::ShapeUtil::GetSubshape(on_host_shape(), {index}); + const xla::Shape& sub_on_device_shape = + xla::ShapeUtil::GetSubshape(on_device_shape(), {index}); + + ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape, + memory_allocator(), device_ordinal()); + auto src_it = buffers().find(index); + auto dst_it = output.buffers().begin(); + while (dst_it != output.buffers().end()) { + dst_it->second = src_it->second; + src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0); + ++src_it; + ++dst_it; + } + return output; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e1d26da4a20c0105be304b1a34c81515fcdc6b7f..f5210c9cfa6b29853bcd0f5bfd581ee3e116a509 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -176,6 +176,11 @@ class ScopedShapedBuffer : public ShapedBuffer { // It's the caller's job to ensure that the memory contained therein is freed. TF_MUST_USE_RESULT ShapedBuffer release(); + // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer + // that holds ownership of the subtree. Sets the buffers corresponding to the + // subtree to null in 'this'. 
+ ScopedShapedBuffer TakeSubTree(ShapeIndexView index); + protected: void Deallocate(); diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index d69e6362e91e4696dab3c46d99a981c67b593a1c..ca64bd3c8dd2baa686db2b85c937a034b37ab22b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -107,5 +109,79 @@ TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { // TestAllocator's destructor checks that all memory was freed. } +TEST(ScopedShapedBufferTest, TestTakeSubTree) { + TestAllocator allocator; + + Shape s = ShapeUtil::MakeShape(F32, {1}); + s = xla::ShapeUtil::MakeTupleShape(std::vector(2, s)); + s = xla::ShapeUtil::MakeTupleShape(std::vector(3, s)); + + ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0); + sb.buffers().ForEachMutableElement( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + TF_ASSERT_OK_AND_ASSIGN( + OwningDeviceMemory m, + allocator.Allocate(/*device_ordinal=*/0, /*size=*/77)); + *buffer = m.Forget(); + }); + ShapeTree buffers = sb.buffers(); + + // Takes a subtree out of 'sb', and verifies the buffers are as expected. 
+ xla::ShapeIndex subtree_index = {1}; + ScopedShapedBuffer output = sb.TakeSubTree(subtree_index); + + output.buffers().ForEachElement([&](const xla::ShapeIndex& sub_index, + const se::DeviceMemoryBase& buffer) { + xla::ShapeIndex orig_index = subtree_index; + for (int i : sub_index) { + orig_index.push_back(i); + } + EXPECT_TRUE(buffers.find(orig_index)->second.IsSameAs(buffer)); + }); + sb.buffers().ForEachElement( + [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + if (ShapeIndexView(index).StartsWith(subtree_index)) { + EXPECT_TRUE(buffer.is_null()); + } else { + EXPECT_TRUE(buffers.find(index)->second.IsSameAs(buffer)); + } + }); +} + +// Test TakeSubTree with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). +void BM_TakeSubTree(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + TestAllocator allocator; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. 
+ (void)shaped_buffer.TakeSubTree(/*index=*/{fan_out / 2}).release(); + } + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_TakeSubTree) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 79b5c09abb355cd067a4891af558c8c44d80d88e..7a565bf076847a4a5f7c98635785c80d86df152d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -172,7 +172,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - auto module = CreateNewModule("fuse_with_constant_operands"); + auto module = CreateNewUnverifiedModule("fuse_with_constant_operands"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( @@ -247,7 +247,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -302,7 +302,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* 
entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -362,7 +362,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -428,7 +428,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewUnverifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index d9ebebf74ed846aa05326a4df72019ef3e71ad88..10ef2d38fa21c3e93c270535bc99b2f76435337d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -48,7 +48,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { } void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); module_->AddEntryComputation(std::move(computation)); } @@ -809,7 +809,7 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { class PointsToAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -1176,7 +1176,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - 
module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -1211,7 +1211,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 516754e2110ee50a597818c4a8bcfbfbb76c5cec..65b0f8c804475d8f22fff9798e79c9881a51f1f1 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloVerifiedTestBase { +class TupleSimplifierTest : public HloTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -65,10 +65,10 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, scalar_shape_, "param2")); builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -78,10 +78,10 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { HloInstruction::CreateParameter(0, tuple_shape_, "param")); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -98,12 +98,12 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), gte); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -125,13 +125,13 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -157,12 +157,12 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0)); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), element); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -182,12 +182,12 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -207,19 +207,19 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { // Verify that the root computation can be excluded - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloInstruction* p0; HloInstruction* p1; @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module, /*change_expected=*/true, /*exclude_entry=*/true); + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 541b117e0299c94de330604ec5c16e20f07c425f..68e2569f66bea9ec1223e454d1ead0efc7b9498e 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" namespace xla { @@ -229,4 +232,96 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, return nullopt; } +// If the only user of this instruction is a get-tuple-element, return that +// get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may +// get a false negative if there are several copies of the same GTE, or there +// are unused GTEs, but we can live with this. 
+static HloInstruction* GetOnlyGTE(HloInstruction* inst) { + if (inst->user_count() != 1) { + return nullptr; + } + + HloInstruction* user = inst->users().back(); + if (user->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; + } + return user; +} + +optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { + // If we know the exact trip count, it's also the upper bound. + auto exact_trip_count = ComputeWhileLoopTripCount(while_op); + if (exact_trip_count) { + VLOG(2) << "Loop has exact trip count."; + return exact_trip_count; + } + + // There is one more case we know how to handle. If the loop condition only + // looks at one element of the tuple, and the loop body sets this element to a + // constant, there are two options: + // 1) Evaluating the condition on this constant returns true. In this case, + // the loop either executes 0 times, or is an infinite loop, depending on the + // init value. + // 2) Evaluating the condition on this constant returns false. In this case, + // the loop executes 0 or 1 times, depending on the init value. This means + // that, regardless of the init value, the upper bound on the trip count is 1. + + // Check whether the condition depends on a single parameter, and find out + // which. + auto* while_cond = while_op->while_condition(); + auto* while_cond_param = while_cond->parameter_instruction(0); + auto* cond_gte = GetOnlyGTE(while_cond_param); + if (!cond_gte) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // Now check whether this gets set to a constant by the while body. 
+ auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(3) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + int64 indvar_index = cond_gte->tuple_index(); + auto* while_body_indvar = while_body_root->operand(indvar_index); + if (while_body_indvar->opcode() != HloOpcode::kConstant) { + VLOG(3) << "While body does not set the IV to a constant: " + << while_body_indvar->ToString(); + return nullopt; + } + + // We have a constant. Evaluate the condition on this constant. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + Literal fake_input = Literal::CreateFromShape(while_cond_param->shape()); + TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), + /*dest_shape_index=*/{indvar_index}, + /*src_shape_index=*/{})); + StatusOr eval_result = + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + + if (!eval_result.ok()) { + VLOG(2) << "Couldn't evaluate while loop condition."; + return nullopt; + } + + Literal cond_result_pred = std::move(eval_result.ValueOrDie()); + CHECK(ShapeUtil::Equal(cond_result_pred.shape(), + ShapeUtil::MakeShape(PRED, {}))); + + // Per the explanation above, if the evaluated condition returns false, the + // loop executes at most once. 
+ bool cond_returns_true = cond_result_pred.GetFirstElement(); + if (!cond_returns_true) { + VLOG(2) << "Upper bound on the trip count is 1"; + return 1; + } + + VLOG(2) << "Loop has no known upper bound on the trip count."; + return nullopt; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index bf497f4892b95c927379411468a66d8961465413..ac69a727bd6b403672a676400993fb7d8afc0a55 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -28,6 +28,10 @@ namespace xla { absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, int64 max_value_returned = 128); +// Returns an upper bound on the trip count of the loop if it's statically +// known, nullopt otherwise. +absl::optional ComputeWhileLoopTripCountUpperBound( + HloInstruction *while_op); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1da0fbeac89a93eaaef893e5f25dd3b87cc1d5d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class WhileLoopAnalysisTest : public HloTestBase {}; + +TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(-1) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); +} + +TEST_F(WhileLoopAnalysisTest, NoUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(42) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], 
s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), absl::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, ExactBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + index = s32[] get-tuple-element(p_body), index=1 + one = s32[] constant(1) + inc = s32[] add(index, one) + ROOT root = (f32[2], s32[]) tuple(val, inc) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] less-than(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] constant(0) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 067cfcc17d65860a249de4d9e31703df12091d3a..8b381dec07397c1427e98bc30511ac21dc577610 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -46,8 +46,9 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( return Status::OK(); } -StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( +StatusOr 
WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( HloInstruction* while_instr) { + HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); const HloInstruction& init_value = *while_instr->operand(0); @@ -57,24 +58,48 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( bool changed = false; - for (HloInstruction* invariant_gte : - WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - int64 index = invariant_gte->tuple_index(); + absl::flat_hash_map> + conditional_gte_index_to_insts = + WhileUtil::GetGTEsMapForWhileConditional(*while_cond); + std::vector invariant_body_gtes = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + for (HloInstruction* invariant_body_gte : invariant_body_gtes) { + int64 index = invariant_body_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); - // Should have at least one user that's not while_body_root. - if (invariant_gte->user_count() <= 1) { + // Original value should be a constant. + if (invariant_value.opcode() != HloOpcode::kConstant) { continue; } - if (invariant_value.opcode() == HloOpcode::kConstant) { - auto* constant_instr = + // Sink into the while_body. + // Should have at least one user that's not while_body_root. + if (invariant_body_gte->user_count() > 1) { + HloInstruction* constant_instr = while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( - invariant_gte, constant_instr, while_body->root_instruction(), + invariant_body_gte, constant_instr, while_body->root_instruction(), index)); changed = true; } + + // Check if there is a corresponding GTE in while_conditional. + auto it = conditional_gte_index_to_insts.find(index); + if (it == conditional_gte_index_to_insts.end()) { + continue; + } + + for (HloInstruction* invariant_cond_gte : it->second) { + // Should have at least one user. 
+ if (invariant_cond_gte->user_count() > 0) { + HloInstruction* constant_instr = while_cond->AddInstruction( + invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR( + invariant_cond_gte->ReplaceAllUsesWith(constant_instr)); + changed = true; + } + } } return changed; @@ -115,10 +140,8 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { } for (HloInstruction* while_instr : while_instrs) { - // We only sink into while loop bodies, but this can be extended to - // transform conditions as well. TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileBody(while_instr)); + TrySinkingConstantsIntoWhileLoop(while_instr)); changed |= result; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 577bad6c7062d2ee40271e407e8eed7655fa13bf..a866bc1264b4013bb7530b5e02b546e6f78d676b 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -23,8 +23,8 @@ limitations under the License. namespace xla { // Sinks while loop invariant values that happen to be constants into the while -// loop body. This is probably not a win in isolation but may unlock further -// optimizations like constant folding. +// loop body and conditional. This is probably not a win in isolation but may +// unlock further optimizations like constant folding. // // state = (..., const, ...) // while (pred(state)) { @@ -46,22 +46,19 @@ namespace xla { // tuple trivially loop invariant. WhileLoopSimplifier will later get rid of // `v`. // -// We only sink into while loop bodies, but this can be extended to transform -// conditions as well. -// // TODO(b/79121449): We should also sink broadcasts of constants. 
class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; absl::string_view name() const override { - return "while-loop-invariant-code-motion"; + return "while-loop-constant-sinking"; } StatusOr Run(HloModule* module) override; private: - StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); + StatusOr TrySinkingConstantsIntoWhileLoop(HloInstruction* while_instr); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index d17b86fab5b14d13250a03fc8f74abb9661ed5ce..75d406435b6f58faecc86b82c33e9e2dd6bccbea 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -242,5 +242,178 @@ ENTRY entry { } } } + +TEST_F(WhileLoopConstantSinkingTest, ConditionalSinkConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[]) p_body), index=1 + ROOT root = (f32[],f32[]) tuple(add, p_body.1) +} + +condition { + p_cond = (f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + while_init = (f32[],f32[]) tuple(const_0, const_1) + ROOT while = (f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* 
while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalTupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[],(f32[],f32[])) parameter(0) + p_b.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_b), index=0 + p_b.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_b), index=1 + p_b.1.0 = f32[] get-tuple-element((f32[],f32[]) p_b.1), index=0 + add = f32[] add(p_b.0, p_b.1.0) + ROOT root = (f32[],(f32[],f32[])) tuple(add, p_b.1) +} + +condition { + p_c = (f32[],(f32[],f32[])) parameter(0) + p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 + p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 + p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 + ROOT result = pred[] less-than(p_c.0, p_c.1.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) + ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::Lt(_, op::GetTupleElement(op::Constant()))); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalDontCreateDeadConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) 
p_body), index=1 + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); + for (const HloInstruction* inst : while_condition->instructions()) { + if (inst->opcode() == HloOpcode::kConstant) { + EXPECT_GT(inst->user_count(), 0); + } + } +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalMultipleSameIndexGTEs) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add.0 = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1 + add.1 = f32[] add(p_body.1, const) + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add.0, add.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 
+ p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.0 = pred[] less-than(p_cond.0, p_cond.2) + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + ROOT result = pred[] and(lt.0, lt.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(0) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant()))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 9795b2830b6d9add82b89ac76b5438ddc3d2bfe8..b7c28bfac7889b788645360366d1419eb80e64de 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" @@ -143,6 +144,12 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( string while_instr_name = while_instr->ToString(print_no_metadata); VLOG(2) << "Trying to hoist from " << while_instr_name; + auto maybe_upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr); + if (maybe_upper_bound && *maybe_upper_bound <= 1) { + VLOG(2) << "Loop has a trip count of at most 1, skipping."; + return false; + } + HloComputation* while_body = while_instr->while_body(); // Maps instructions in the while body to instructions hoisted outside the @@ -180,6 +187,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( return false; } + // LICM in the presence of domain instructions is complex, bail. + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + return false; + } + } + // instructions_to_replace[i] is hoisted into a loop invariant instruction // replacement_instructions[i]. std::vector instructions_to_replace; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b713c438bd7fcb2053709b0624f58ed..046ccb2d3f29c2141ade5275d043875e3e278582 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -26,7 +26,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { +class WhileLoopInvariantCodeMotionTest : public HloTestBase { public: // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. @@ -58,6 +58,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( } TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -76,19 +77,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + 
WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -100,6 +100,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { } TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -135,19 +136,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -173,6 +173,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistTriviallyLoopVaryingComputation) { // Basic negative test: the add expression is not loop invariant. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -189,20 +190,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); @@ -210,6 +211,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -228,25 +230,26 @@ TEST_F(WhileLoopInvariantCodeMotionTest, builder.AddInstruction( HloInstruction::CreateTuple({gte_1, gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = 
builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = @@ -267,7 +270,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -277,14 +280,14 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); ASSERT_FALSE(simplified_loop); 
EXPECT_THAT(while_inst->while_body()->instructions(), @@ -294,6 +297,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the // bitcast either. + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); @@ -317,7 +321,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -327,15 +331,15 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), @@ -346,6 +350,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); Shape while_shape = @@ -367,21 +372,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -396,6 +400,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -416,22 +421,23 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - while_body = module().AddEmbeddedComputation(builder.Build()); + while_body = m->AddEmbeddedComputation(builder.Build()); } HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); 
builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -439,7 +445,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { HloComputation::Builder builder(TestName() + ".passthrough"); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + HloComputation* result = m->AddEmbeddedComputation(builder.Build()); result->AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); @@ -450,11 +456,11 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } @@ -482,14 +488,14 @@ ENTRY entry { )"; TEST_F(WhileLoopInvariantCodeMotionTest, 
HoistsConstantWhenAsked) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN( bool simplified_loop, - WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(m.get())); EXPECT_TRUE(simplified_loop); - HloComputation* while_body = module().GetComputationWithName("wide.body"); + HloComputation* while_body = m->GetComputationWithName("wide.body"); ASSERT_NE(while_body, nullptr); // We expect the while body to be the equivalent of: @@ -523,10 +529,44 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { } TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], f32[2], f32[2], s32[]) parameter(0) + val.0 = f32[2] get-tuple-element(p_body), index=0 + val.1 = f32[2] get-tuple-element(p_body), index=1 + add = f32[2] add(val.0, val.1) + const = s32[] constant(-1) + ROOT root = (f32[2], f32[2], f32[2], s32[]) tuple(val.0, val.1, add, const) + } + + condition { + p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=3 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], f32[2], f32[2], s32[]) tuple(param.0, param.0, param.0, param.1) + ROOT while = (f32[2], f32[2], f32[2], s32[]) while(while_init), 
condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(module.get())); EXPECT_FALSE(simplified_loop); } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 630d71e5ca25e9d282ce6283284a32d6f725a193..6f924a29d8a3ac60abe98efd2e82ae7343c7de47 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -20,40 +20,14 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" namespace xla { using absl::optional; - -// Determines whether the given instruction is a send/recv node, or has a -// subcomputation which contains a send/recv node. -static bool IsOrContainsSendOrRecv(const HloInstruction* instr); - -// Determines whether the given computation contains a send or recv node. 
-static bool ContainsSendOrRecv(const HloComputation* comp) { - for (const auto* instr : comp->instructions()) { - if (IsOrContainsSendOrRecv(instr)) { - return true; - } - } - return false; -} - -static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { - if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kSendDone || - instr->opcode() == HloOpcode::kRecv || - instr->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (const auto& subcomp : instr->called_computations()) { - if (ContainsSendOrRecv(subcomp)) { - return true; - } - } - return false; -} +using hlo_query::ContainsInstrWithOpcode; // Tries to remove elements in a while loop's tuple that aren't used within the // loop. @@ -253,7 +227,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond), /*extras=*/{}); + make_while_computation_replacements(while_cond)); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -266,8 +240,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements), - /*extras=*/{}); + while_body->CloneWithReplacements(std::move(while_body_replacements)); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -458,6 +431,180 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { return changed_cond || changed_body; } +// Converts a flat list of instructions into a tuple of the desired shape. 
For +// example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns +// a tuple of value ((A, B), C). +// +// desired_shape must be a tuple. (This precondition allows us to return a +// unique_ptr rather than a raw ptr.) +static std::unique_ptr UnflattenTupleInstr( + absl::Span instrs, const Shape& desired_shape, + std::vector>* new_instrs) { + CHECK(ShapeUtil::IsTuple(desired_shape)) + << ShapeUtil::HumanString(desired_shape); + + // For each child shape in `desired_shape`, slice out the correct number of + // `instrs` and call UnflattenTupleInstr recursively. At each step we remove + // elements from `instrs` so that it only contains instructions we have not + // yet processed. + std::vector elems; + for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { + const Shape& subshape = desired_shape.tuple_shapes(i); + if (!ShapeUtil::IsTuple(subshape)) { + elems.push_back(instrs[0]); + instrs.remove_prefix(1); + continue; + } + + // Count the number of leaf nodes underneath desired_shape[i]. + int64 num_leaves = 0; + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + ++num_leaves; + } + }); + + std::unique_ptr subinstr = + UnflattenTupleInstr(instrs.subspan(0, num_leaves), + desired_shape.tuple_shapes(i), new_instrs); + elems.push_back(subinstr.get()); + new_instrs->push_back(std::move(subinstr)); + instrs.remove_prefix(num_leaves); + } + return HloInstruction::CreateTuple(elems); +} + +// Builds a vector whose elements are the values in the flattened tuple for +// `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the +// vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C). 
+static std::vector GetFlatTupleElems( + HloInstruction* instr, + std::vector>* new_instrs) { + const auto& shape = instr->shape(); + if (!ShapeUtil::IsTuple(shape)) { + return {instr}; + } + std::vector elems; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + new_instrs->push_back( + HloInstruction::CreateGetTupleElement(subshape, instr, i)); + auto* gte = new_instrs->back().get(); + auto flattened_subshape = GetFlatTupleElems(gte, new_instrs); + elems.insert(elems.end(), flattened_subshape.begin(), + flattened_subshape.end()); + } + return elems; +} + +static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + if (!ShapeUtil::IsNestedTuple(while_shape)) { + return false; + } + + // Cowardly refuse to perform this optimization in the presence of kDomain + // instructions, which may reference other instructions in the loop and + // therefore make this complicated. 
+ if (ContainsInstrWithOpcode(while_body, {HloOpcode::kDomain}) || + ContainsInstrWithOpcode(while_cond, {HloOpcode::kDomain})) { + return false; + } + + std::vector flattened_shape_elems; + ShapeUtil::ForEachSubshape(while_shape, + [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + flattened_shape_elems.push_back(s); + } + }); + Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. + std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto nested = [&](HloInstruction* instr) { + std::vector gtes; + const Shape& flat_shape = instr->shape(); + for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) { + gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement( + flat_shape.tuple_shapes(i), instr, i))); + } + auto nested_instr = + UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs); + CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape)) + << ShapeUtil::HumanString(nested_instr->shape()) << " vs " + << ShapeUtil::HumanString(while_shape); + return nested_instr; + }; + + auto flattened = [&](HloInstruction* instr) { + return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs)); + }; + + // Create a new while-condition computation, where parameter 0 has flat shape + // but all uses of it go through the nested shape. 
+ std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Create a new while-body computation, where parameter 0 has a flat shape and + // all uses of it go through the nested shape, and where the root has a flat + // shape constructed from the old nested root. + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_body->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + flattened(add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile( + flattened_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + computation->AddInstruction(flattened(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -478,32 +625,46 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { for (HloInstruction* while_op : while_ops) { // We can't remove while loops that contain send/recv nodes, because we rely // on the particular loop structure around the node matching on the send and - // recv sides. Removing dead while params requires us to remove the loop + // recv sides. 
Other while simplifications require us to remove the loop // and replace it with a new one, so we can't do that either. - if (ContainsSendOrRecv(while_op->while_body()) || - ContainsSendOrRecv(while_op->while_condition())) { + if (ContainsInstrWithOpcode(while_op->while_body(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone})) { VLOG(2) << "Not attempting to simplify while loop because it contains a " "send/recv node: " << while_op->ToShortString(); continue; } - StatusOr result = TryPropagateConstant(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); + changed |= result; + + TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); + changed |= result; + if (result) { + // Don't continue simplifying after successfully removing the while loop + // -- that would result in use-after-free nastiness. + continue; + } - result = TryRemoveWhileLoop(while_op); - TF_RETURN_IF_ERROR(result.status()); - if (result.ValueOrDie()) { - changed = true; - // Don't try to remove dead while params after successfully removing the - // while loop -- that would result in use-after-free nastiness. + TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); + changed |= result; + if (result) { + // Successfully flattening nested tuples results in us cloning and + // replacing the while loop, meaning that `while_op` is no longer valid. 
continue; } - result = TryRemoveDeadWhileParams(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; + if (result) { + // Successfully removing dead while params results in us cloning and + // replacing the while loop, meaning that `while_op` is no longer valid. + continue; + } } XLA_VLOG_LINES(3, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 0bc5a0107bbcfb3b29a01d593fb79b89a863e49b..a378f179c63c788cd205ddbb784dee0e6b2106d7 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,11 +25,22 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. +// // - A while loop with static trip count of 1 is replaced by its body (sans // loop). +// // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // +// - If the while loop's parameter is a nested tuple, it's flattened to a +// single-level tuple. This is good because it usually reduces the number of +// kTuple instructions, but also because it unlocks additional optimizations +// (e.g. removing unused loop parameters). +// +// Flattening nested while loop tuples adds a whole mess of likely unnecessary +// kGetTupleElement and kTuple operations to the graph. We expect that tuple +// simplifier will be run afterwards. 
+// class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 1c892ba179ec67ccc9dbfe93d925551d6977ba15..05005e0b262a50cd40e004deac4c450a2e257308 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -17,9 +17,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -27,18 +29,21 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopSimplifierTest : public HloVerifiedTestBase { +class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. - void MakeModuleWithSimpleLoop(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoop(int num_iters); // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to // the loop-condition through an element of a tuple which is the // loop-condition parameter. 
- void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string_template = R"( HloModule SimpleLoop SimpleLoop.body { @@ -67,10 +72,11 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( int num_iters) { string hlo_string_template = R"( HloModule SimpleLoopWithIndirectLoopBound @@ -104,60 +110,55 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationTupleElementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); - HloModule* the_module = &module(); - 
ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationTupleELementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/2); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/2); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithControlDependencySimplifiedDependencyPreserved) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = 
the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -166,9 +167,8 @@ TEST_F(WhileLoopSimplifierTest, // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -179,13 +179,12 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = 
computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -194,7 +193,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and @@ -202,16 +201,15 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A non-tuple shaped loop shouldn't be simplified or crash the compiler. 
@@ -236,8 +234,8 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A while loop that does nothing else besides swapping tuple elements @@ -268,8 +266,8 @@ TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each @@ -297,8 +295,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Nothing to simplify in a while loop whose tuple has 0 elements. @@ -320,8 +318,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't @@ -348,8 +346,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // This while loop has three tuple elements. 
Element 0 is unused and should be @@ -390,16 +388,15 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { } )"; - ParseAndVerifyModule(hlo_string); - HloModule* the_module = &module(); - EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. HloInstruction* new_while_op = - *std::find_if(the_module->entry_computation()->instructions().begin(), - the_module->entry_computation()->instructions().end(), + *std::find_if(m->entry_computation()->instructions().begin(), + m->entry_computation()->instructions().end(), [&](const HloInstruction* instr) { return (instr->opcode() == HloOpcode::kWhile && instr->name() != "while"); @@ -440,8 +437,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, @@ -473,8 +470,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { @@ -505,8 +502,65 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + 
EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { + const string hlo_string = R"( + HloModule Test + Body { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ta = (s32[1]) get-tuple-element(param), index=0 + a = s32[1] get-tuple-element(ta), index=0 + a.1 = s32[1] add(a, a) + tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1 + ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + } + Cond { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + b = s32[2] constant({0,1}) + c = s32[3] constant({0,1,2}) + d = s32[4] constant({0,1,2,3}) + ta = (s32[1]) tuple(a) + td = (s32[4]) tuple(d) + tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td) + init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. 
+ EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + + const auto& instrs = m->entry_computation()->instructions(); + HloInstruction* new_while = + *absl::c_find_if(instrs, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + Shape flat_tuple = + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") + .ValueOrDie(); + SCOPED_TRACE(m->ToString()); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") + .ValueOrDie())); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1f583ca44b7d20d56f27560f4a97a38c3fcc3026..039ccda7322f5efda6a827efbeda1225c3596cc0 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_util.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -270,4 +272,17 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { return result; } +/*static*/ absl::flat_hash_map> +WhileUtil::GetGTEsMapForWhileConditional( + const HloComputation& while_conditional) { + absl::flat_hash_map> result; + for (HloInstruction* user : + while_conditional.parameter_instruction(0)->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + result[user->tuple_index()].push_back(user); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 524dcec5f12689027ef76b8ae180bcbcc7cff601..cba41ccd8b184ba3d867bc170724aee71e777788 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -85,6 +87,13 @@ class WhileUtil { // Assumes `while_body` is the body computation of the while loop in question. static std::vector GetInvariantGTEsForWhileBody( const HloComputation& while_body); + + // Returns a map of index to GetTupleElement instructions in + // `while_conditional` that access elements in the parameter tuple. Assumes + // `while_conditional` is the conditional computation of the while loop in + // question. 
+ static absl::flat_hash_map> + GetGTEsMapForWhileConditional(const HloComputation& while_conditional); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index b9ef18892d7aa859f6b0b505db4c004e4f5c5066..a546a6d39cc55d1f327b8449c7d26cd4c95dbf98 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -45,7 +45,8 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - auto module = CreateNewModule("zero_sized_elimination_test_module"); + auto module = + CreateNewUnverifiedModule("zero_sized_elimination_test_module"); module->AddEntryComputation(builder_.Build()); return ZeroSizedHloElimination{}.Run(module.get()); } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 14c35e7b84f07bebac33a9753ac26a8ee1418f1e..33edbd1b20d01bf132f2a152625d5f49a45f26f9 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,8 +47,11 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Compile(const CompileRequest* arg, + CompileResponse* result) = 0; + + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 17120e610cb26dda41fffd28fdb2b9e8bdffb973..d0c35d8dee46a1e0a5e343e0506a14ca1ce38bfd 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ 
b/tensorflow/compiler/xla/shape_util.cc @@ -74,6 +74,11 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { return out; } +bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { + return size() >= prefix.size() && + indices_.subspan(0, prefix.size()) == prefix.indices_; +} + namespace { // Returns whether the given primitive type corresponds to an array shape. diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 191ab04759f2d0ae87d988cba0d303f1ab696432..a7a3026cf3f3a53d34d389212738ca584a19db1d 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -147,6 +147,9 @@ class ShapeIndexView { string ToString() const; + // Returns true if this shape index starts with 'prefix'. + bool StartsWith(ShapeIndexView prefix) const; + private: absl::Span indices_; }; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d395c9a4ceecfbd38076ac51f5a18da2ef098abb..db34d34f969311543d988ec6c3b8ee2af5b07e8e 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,7 +44,7 @@ cc_library( testonly = True, srcs = ["xla_internal_test_main.cc"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -117,12 +117,12 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", @@ -141,44 +141,6 
@@ cc_library( ], ) -cc_library( - name = "hlo_verified_test_base", - testonly = True, - srcs = ["hlo_verified_test_base.cc"], - hdrs = ["hlo_verified_test_base.h"], - deps = [ - ":hlo_test_base", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "hlo_verified_test_base_test", - srcs = ["hlo_verified_test_base_test.cc"], - deps = [ - ":hlo_test_base", - ":hlo_verified_test_base", - ":test_macros_cpu", - ":test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -868,7 +830,8 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], - shard_count = 25, + shard_count = 40, + tags = ["optonly"], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 9966e4606ef7f104487182e0240e64e4c9e4d834..9930bfc95c297093584d427397cac042c296050f 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { ShapeUtil::MakeShape(F32, {}), input, {})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); LOG(INFO) << hlo_module->ToString(); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 3aebf784664dac14ba2ea45c5a229b7b2e4fc39d..211d004ec8c0a04b17c2454995880c0b565d3d4d 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -596,6 +596,272 @@ TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { this->RunTest(); } +template +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + 
+TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + + auto 
input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 1024}; + std::vector 
filter_dims = {3, 3, 1, 1024}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1024); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(4096, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + 
+TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) { + this->RunTest(); +} + template class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { public: diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1407e68d9a336b6bb1c960711015430f872aa912..3622f2c1e84639baed13059b21b20609d1347da6 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -45,7 +45,7 @@ class CopyOpTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -98,7 +98,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {&literal}); @@ -119,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, @@ -143,7 +143,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -175,7 +175,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, 
size_t n2, size_t n3) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -209,7 +209,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); Literal result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 001490c6a8c568656437465054ee4db40d0d8dee..738b6442354b01364278e3e3c713aa2cdb5cf47d 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -70,7 +70,7 @@ class CustomCallTest : public HloTestBase { }; XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -85,7 +85,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); Array2D array(2, 2); @@ -106,7 +106,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( @@ 
-130,7 +130,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = @@ -155,7 +155,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { // The argument and result of the computation are set to different layouts, // but the custom call is layout constrained to a fixed operand and result // layout, so the correct result should be produced. - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 4d4b676a538947c8dd92a7e34db72e45766cae2c..d1fddf9d6b494a822610e41307fa103dc90bdef3 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -81,7 +81,7 @@ class FusionTest : public HloTestBase { } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -183,7 +183,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -231,7 +231,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. 
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -266,7 +266,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); // Build simple fusion computation: y = x^2 (elementwise). auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto two = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); @@ -290,7 +290,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -314,7 +314,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -329,7 +329,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, 
{3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -344,7 +344,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -359,7 +359,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -374,7 +374,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -389,7 +389,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -404,7 +404,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 
LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( @@ -419,7 +419,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -434,7 +434,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -449,7 +449,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -465,7 +465,7 @@ XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -483,7 +483,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto 
hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -501,7 +501,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( @@ -519,7 +519,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( @@ -541,7 +541,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( @@ -559,7 +559,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -587,7 +587,7 @@ std::unique_ptr MakeReduceTestComputation() { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { - auto hlo_module = CreateNewModule(); + 
auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -607,7 +607,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -630,7 +630,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( @@ -682,7 +682,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // into a fusion, it should remain shared, rather than being duplicated // within the fusion. XLA_TEST_F(FusionTest, SharedConstant) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 7ab2ecda58666acd7e9b8587d200a902b75822f3..d8fa00272f8f19ab843fd32a66fd6d6842997bdb 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -85,6 +85,23 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); + } + return verifier_.Run(this).status(); +} + +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? "" : absl::StrCat(" (", message, ")")) + << ": " << status; + } +} + HloTestBase::HloTestBase(bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, std::function @@ -100,17 +117,48 @@ HloTestBase::HloTestBase(se::Platform* test_platform, bool allow_mixed_precision_in_hlo_verifier, std::function instruction_can_change_layout_func) - : test_runner_(test_platform), reference_runner_(reference_platform) { + : test_runner_(test_platform), + reference_runner_(reference_platform), + verifier_layout_sensitive_(verifier_layout_sensitive), + allow_mixed_precision_in_hlo_verifier_( + allow_mixed_precision_in_hlo_verifier) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, instruction_can_change_layout_func); } -std::unique_ptr HloTestBase::CreateNewModule(const string& name) { +std::unique_ptr HloTestBase::CreateNewUnverifiedModule( + const string& name) { return absl::make_unique(name, GetModuleConfigForTest()); } 
+std::unique_ptr HloTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + +StatusOr> +HloTestBase::ParseAndReturnUnverifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique(TestName(), config); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + return std::move(module); +} + +StatusOr> +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique( + TestName(), config, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + /* static */ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, HloModule* module) { @@ -135,7 +183,7 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { } DebugOptions HloTestBase::GetDebugOptionsForTest() { - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 217428befa474448cf2dcbae2eb6cb5b0e61d44c..366726d90b4752b6d53dc2133c8b0b5bbafce086 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -38,6 +38,31 @@ limitations under the License. namespace xla { +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. 
+class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + // A base class for tests which build and/or run HLO code. The class includes // support for running an HLO module on two platforms and compare the results. // This is a lower level of abstraction than using the client interface and @@ -72,7 +97,27 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - std::unique_ptr CreateNewModule(const string& name = TestName()); + // + // This returns a vanilla HloModule that doesn't run the HLO verifier on + // destruction. + std::unique_ptr CreateNewUnverifiedModule( + const string& name = TestName()); + + // Like CreateNewUnverifiedModule, except the HloModule returned here runs the + // HLO verifier on destruction. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + + // Parses the given string and returns module as a vanilla, unverified + // HloModule. + StatusOr> ParseAndReturnUnverifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + + // Parses the given string and returns module as a VerifiedHloModule. 
+ StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -247,6 +292,8 @@ class HloTestBase : public ::testing::Test { HloRunner test_runner_; HloRunner reference_runner_; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc deleted file mode 100644 index 8bd0a729b77f3ec14204952cb0062103c823883e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. - return Status::OK(); - } - return verifier_.Run(this).status(); -} - -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { - Status status = Verify(); - if (!status.ok()) { - ADD_FAILURE() << "HloVerifier failed on module " << name() - << (message.empty() ? "" : absl::StrCat(" (", message, ")")) - << ": " << status; - } -} - -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), - verifier_layout_sensitive_(layout_sensitive), - allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} - -HloModule& HloVerifiedTestBase::module() { - if (!module_) { - module_ = CreateNewVerifiedModule(TestName()); - } - return *module_; -} - -HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(CreateNewVerifiedModule(name)); - return modules_.back().get(); -} - -void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - CHECK(!module_) << "Called ParseModule when test already has a module."; - module_ = CreateNewVerifiedModule(TestName()); - TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); - module_->VerifyOrAddFailure("after parsing"); -} - -StatusOr> 
-HloVerifiedTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) { - auto module = CreateNewVerifiedModule(TestName()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); - return std::move(module); -} - -std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( - const string& name) { - return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h deleted file mode 100644 index 388a99bb36408665edbc20ade6c6a733d64db88d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ -#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ - -#include -#include -#include - -#include "absl/base/macros.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { - -// An HLO module derived class which verifies itself on destruction. 
This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. -class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) - : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Verifies the module using HloVerifier and returns the status. - Status Verify(); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(const string& message); - - private: - HloVerifier verifier_; -}; - -// A base class for HLO tests that stores a default VerifiedHloModule. -class HloVerifiedTestBase : public HloTestBase { - protected: - HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - - // Constructs a default shape verifier. - std::unique_ptr MakeShapeVerifier(); - - // Returns the default HloModule, lazily creating it if necessary via - // HloTestBase::CreateNewModule(). - ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule& module(); - - ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") - void ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Creates a new module for a test, and stores it in modules_ so it can be - // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent - // creation of unverified modules. 
- ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule* CreateNewModule(const string& name = TestName()); - - // Creates and returns a verified HLO module with the given name. - std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); - - private: - // It is confusing to store modules created by module() and CreateNewModule() - // in different fields, but it allows us to migrate tests to - // HloVerifiedTestBase more easily, so it's a win because we can verify more - // modules. See b/80488902. - // - // Lazily populated. Access via module(). - std::unique_ptr module_; - - // Populated by calls to CreateNewModule. - std::vector> modules_; - - bool verifier_layout_sensitive_; - bool allow_mixed_precision_in_hlo_verifier_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc deleted file mode 100644 index 5c0263e811f94c90a69a460525ffa0c65127ebb5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// This class includes unit tests which are expected to fail because invalid HLO -// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to -// include the necessary gunit parts to test this test machinery (needs the -// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the -// disabled tests enabled and failures can be manually compared against -// expectations. -class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; - -XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { - // Test shouldn't fail if no module is created at all. -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { - // Use module() to lazily create an empty module, build it up, and verify no - // failures. - HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { - // Use module() to lazily create an empty module and build up an invalid - // module. 
- HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); - - *hlo_module.entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { - // Call CreateNewModule and build up a valid module. - HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { - // Call CreateNewModule and build up a invalid module. 
- HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); - - *module->entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndVerifyModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - ParseAndVerifyModule(hlo_string); - EXPECT_EQ(module().entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - EXPECT_EQ(module->entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} - -RANDOM GARBAGE -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -// This test is expected to fail. See test class comment. 
-XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleBad - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[1234] add(x,y) -} -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index c622b295094e53e63d0ed692d428bc97724c787c..a78ccacec114858740bf1b9c04e9b688bca5818d 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -68,7 +68,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); compiler->SetPreOptimizationHook(pre_opt_hook); @@ -90,7 +90,7 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - std::unique_ptr hlo_module = CreateNewModule(); + std::unique_ptr hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto module_group = absl::make_unique("test_module_group"); @@ -124,9 +124,9 @@ class LLVMCompilerTest : public ::testing::Test { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - static std::unique_ptr CreateNewModule() { + static std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); return absl::make_unique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc 
b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index ca7637a0cfa5d837dfd86aadafd1e5cc19ffc22e..3f5135438fc59bea98527b1be30ee49339edd455 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -62,7 +62,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); const Shape elem_shape2 = @@ -122,7 +122,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest1D(bool manual_fusion, int size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape_F32 = ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 2cc33ab0963afe8ba2d8e9a6972dcf0622e27c48..3fb69419e735bfd9c5054673e0687f5139a410cb 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -166,6 +166,26 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } +TEST_F(SliceTest, SliceOfReshape) { + Array2D values(2 * 3 * 24, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR2FromArray2D(&builder, values); + auto reshape = Reshape(original, {24, 3, 2, 7}); + Slice(reshape, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1}); + ComputeAndCompare(&builder, {}); +} + +TEST_F(SliceTest, SliceOfCollapsingReshape) { + Array4D values(2, 3, 5, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR4FromArray4D(&builder, values); + auto reshape = Reshape(original, {2 * 3 * 5, 7}); + 
Slice(reshape, {0, 0}, {4, 7}, {1, 1}); + ComputeAndCompare(&builder, {}); +} + XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { Array4D values(2, 4, 6, 8); values.FillRandom(3.14f); @@ -253,7 +273,6 @@ XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } - // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b34fd0f2e873214c509533f29553af914ddc984d..a2b7c26331b3cc89ed0413efe8eb31c2b9e37038 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -28,7 +28,7 @@ namespace { class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateToken()); @@ -38,8 +38,22 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } +XLA_TEST_F(TokenHloTest, TokenInTuple) { + std::unique_ptr module = CreateNewUnverifiedModule(); + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateTuple({token})); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + Literal token_literal = LiteralUtil::CreateToken(); + EXPECT_TRUE( + LiteralTestUtil::Equal(result, LiteralUtil::MakeTuple({&token_literal}))); +} + XLA_TEST_F(TokenHloTest, TokenTree) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); 
auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); @@ -54,7 +68,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) { } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); @@ -75,7 +89,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateParameter( 0, @@ -95,7 +109,7 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 376559500efad6a756f8a0f60f0a522db047c0e5..ca036f1ae0d5e31a3f83d9d31c80e070c2a666df 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -91,8 +91,8 @@ Status ParseOneProfileOutputLine( string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; - string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; - string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?"; + string match_bytes_per_cycle = 
"([0-9.TGMKi]*)(?:B/cycle)?"; // The underlined part is what we're trying to match with match_opcode: // @@ -307,6 +307,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape, matrix_shape); + SCOPED_TRACE(profile_output); std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); @@ -318,14 +319,13 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = std::find_if( - while_body_profile_start, profile_output_lines.end(), - [](absl::string_view s) { - return absl::StartsWith(s, "********** microseconds report **********"); - }); + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds "); + }); - // We emit a blank line before the "********** microseconds report **********" - // line. + // We emit a blank line before the "microseconds report" line. while_body_profile_end--; ASSERT_NE(while_body_profile_end, profile_output_lines.end()); @@ -380,7 +380,7 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 15603619b62d8f45cdce97ac7d83924a78f88cf3..dca0aa52a533130372759156a3238f1a3b10ca42 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -15,14 +15,14 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { LOG(ERROR) << "\n" << usage; @@ -49,7 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { // different API than Tensorflow's. testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) - base::SetFlag(&FLAGS_benchmarks, pattern); + absl::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); #else tensorflow::testing::Benchmark::Run(pattern); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 3a086c66bbb37965b1ad7c83a93f0054ae723e87..8926bbed2b54fceaaf0e6e991f0e881d35731ef4 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -33,6 +33,7 @@ cc_library( name = "dumped_computation_to_graphviz_library", srcs = ["dumped_computation_to_graphviz.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -40,7 +41,6 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -78,6 +78,7 @@ cc_library( name = "replay_computation_library", srcs = 
["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -91,7 +92,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -207,13 +207,13 @@ tf_cc_binary( name = "dumped_computation_to_tf_graphdef", srcs = ["dumped_computation_to_tf_graphdef.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index c866a13de7543fc948311f94708bc6b904717b62..b623556468fb4a5d96be614b6c067d5a1df51a6f 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -54,7 +54,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = client->GetComputationStats(computation, debug_options) @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 07ef5ff656bb48519a700a1d7d6c60b655a40ed6..f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,7 +53,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); ComputationStats stats = @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 109411f99b6eb000474b0c61783c51f42d43bb6d..47be9f5adf1063463d7678579a7f394684aaf357 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -47,8 +47,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -191,8 +191,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, // Run the computation num_runs times, and return the result from the last // execution. - const bool xla_hlo_profile = - legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 65948ef4b0c3d51805b15634e6215f192e740aaa..28df3b03f398841460189910bc3a5096dfb0d367 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -322,6 +322,34 @@ message UnregisterRequest { message UnregisterResponse { } +message CompileRequest { + // The graph to be compiled. + HloModuleProto computation = 1; + + // Options that affect how XLA compiles code to service this request. + ExecutionOptions execution_options = 2; + + // The layouts of the input arguments. If not set, the default layout will be + // used. Although the real arguments are not needed in compilation, the + // layouts of the arguments can affect the compilation. + repeated Shape input_shape_with_layout = 3; +} + +message CompileResponse { + // The handle to the executable. 
+ ExecutionHandle handle = 1; +} + +message ExecuteRequest { + ExecutionHandle handle = 1; + + // The shape and layout of the arguments must be the same as the those of the + // executable's parameters. + repeated GlobalDataHandle arguments = 2; +} + +// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace +// the uses with calls to Compile and Execute. message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index b6bd919e2b26a109cb9dfd2a6aaba86f1732cff1..683ccc40f162ead3a248aee83d9abf3086a1ac93 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -332,11 +332,13 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; - // The F16s and BF16s are encoded in little endian byte order + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 16 + // Next = 18 } message WindowDimension { diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 9e3d2454d16730c1d1f93cb384db88544380f77e..67f475846e5f16060c1080759b0acb4216c4e72b 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -12,6 +12,7 @@ cc_library( hdrs = ["xrt_state_ops.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -21,7 +22,6 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 5678f0905ff5b8956e0811026e7450acba8815e9..6ab77fbaaf0cbe23503ebc71775f52af01e41a74 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -6,6 +6,24 @@ import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; +message DeviceAssignment { + message ComputationDevice { + message DeviceMeshCoordinates { + // The mesh coordinates for the device. Usually (X, Y, Core), in the order + // in which they are returned in the TopologyProto. + // X = value(0) + // Y = value(1) + // Core = value(2) + repeated int32 value = 1; + } + // As many replicas as there are in the replicated computation. + repeated DeviceMeshCoordinates replica_devices = 1; + } + // As many ComputationDevice as many there are computations (number + // of cores per replica). + repeated ComputationDevice computation_devices = 1; +} + // Options for an XLA compilation. message XLAComputationConfig { // The number of replicas the computation will be run on. If this is @@ -23,6 +41,11 @@ message XLAComputationConfig { // computation. per_core_args_and_result_shapes is optional for a // single-core computation. repeated xla.ProgramShape per_core_program_shape = 5; + // Describes how replicated computation instances should be assigned to + // devices. There are num_cores_per_replica computations, and each one will be + // sent and executed to the set of replica device numbers described in the + // DeviceAssignment proto. + DeviceAssignment device_assignment = 6; } // Options and XLA computation for a compilation. 
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f45010ec26ed25127ca78b97f4d6fd7ebd6467ae..1fffbb5f660c681e1dde11a2aaf1d0f1cf79d1d0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -142,7 +142,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): name="StatsAccumulator/{}".format(self._name)) # Allocate both stats accumulator and quantile accumulator on the same # device so that we can build splits with fewer RPCs. - with ops.colocate_with(self._stats_accumulator.resource()): + with ops.colocate_with(self._stats_accumulator.resource_handle): self._quantile_accumulator = quantile_ops.QuantileAccumulator( init_stamp_token, epsilon=epsilon, @@ -268,8 +268,8 @@ class DenseSplitHandler(InequalitySplitHandler): handler = make_dense_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, @@ -447,8 +447,8 @@ class SparseSplitHandler(InequalitySplitHandler): handler = make_sparse_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, 
self._tree_complexity_regularization, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 05ce0884ccfff53484fdc0c26e596e7fb6fcdfd6..356ae337685d580319da16a20bbab27ccaa73255 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -34,7 +34,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -62,7 +62,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2, 1], @@ -91,7 +91,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -123,7 +123,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -133,7 +133,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with 
ops.control_dependencies([op1]): (stamp_token, num_updates, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -164,7 +164,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. op1 = accumulator.add( stamp_token=0, @@ -175,7 +175,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): deserialize = ( - accumulator.deserialize( + accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], @@ -223,7 +223,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -261,7 +261,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -299,7 +299,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with 
ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -336,7 +336,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -349,7 +349,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -386,7 +386,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. 
op1 = accumulator.add( stamp_token=0, @@ -399,7 +399,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): 0.08]]]) with ops.control_dependencies([op1]): - deserialize = accumulator.deserialize( + deserialize = accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 25b2c9e2fd72bd018717e8a87fce726f26bad968..fca22c71a83459cb290eaebcf107cf1c14c222b7 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader # pylint: enable=unused-import @@ -31,6 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -82,6 +85,44 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): tree_ensemble_config=restored_tensors[1]) +class TreeEnsembleVariable(tracking.TrackableResource): + """A Tree ensemble model.""" + + def __init__(self, stamp_token, tree_ensemble_config, name, container=None): + self._stamp_token = stamp_token + self._tree_ensemble_config = tree_ensemble_config + self._name = name + self._container = container + self._init_op = None + super(TreeEnsembleVariable, self).__init__() + + def create_resource(self): + return gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._container, shared_name=self._name, 
name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_ensemble_variable( + self.resource_handle, self._stamp_token, self._tree_ensemble_config) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_ensemble_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return { + "tree_ensemble_variable": + functools.partial( + TreeEnsembleVariableSavable, + tree_ensemble_handle=self.resource_handle, + create_op=self.initializer) + } + + def tree_ensemble_variable(stamp_token, tree_ensemble_config, name, @@ -99,12 +140,11 @@ def tree_ensemble_variable(stamp_token, A `Tensor` of type mutable `string`. The handle to the tree ensemble. """ with ops.name_scope(name, "TreeEnsembleVariable") as name: - resource_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op( - container, shared_name=name, name=name) - create_op = gen_model_ops.create_tree_ensemble_variable( - resource_handle, stamp_token, tree_ensemble_config) - is_initialized_op = gen_model_ops.tree_ensemble_is_initialized_op( - resource_handle) + tree_ensemble_var = TreeEnsembleVariable(stamp_token, tree_ensemble_config, + name, container) + resource_handle = tree_ensemble_var.resource_handle + create_op = tree_ensemble_var.initializer + is_initialized_op = tree_ensemble_var.is_initialized() # Adds the variable to the savable list. 
saveable = TreeEnsembleVariableSavable(resource_handle, create_op, resource_handle.name) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 19b6b3296db394b07f57a25dbde187eb9195af38..0c319cc9bd1f720eb404a9da05227c5807ec874f 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,59 +33,20 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") -class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): - """A resource that allows distributed quantile computation.""" - - def __init__(self, - init_stamp_token, - epsilon, - num_quantiles, - max_elements=None, - name=None, - container=None, - generate_quantiles=False): - """Creates a QuantileAccumulator object. - - Args: - init_stamp_token: The initial value for the stamp token. - epsilon: Error bound on the quantile computation. - num_quantiles: Number of quantiles to produce from the final summary. - max_elements: Maximum number of elements added to the accumulator. - name: the name to save the accumulator under. - container: An optional `string`. Defaults to `""` - generate_quantiles: Generate quantiles instead of approximate boundaries. - If true, exactly `num_quantiles` will be produced in the final summary. 
- """ - self._epsilon = epsilon - self._generate_quantiles = generate_quantiles +class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for QuantileAccumulator.""" - name = _PATTERN.sub("", name) - with ops.name_scope(name, "QuantileAccumulator") as name: - self._quantile_accumulator_handle = ( - gen_quantile_ops.quantile_stream_resource_handle_op( - container=container, shared_name=name, name=name)) - self._create_op = gen_quantile_ops.create_quantile_accumulator( - self._quantile_accumulator_handle, - init_stamp_token, - epsilon=epsilon, - max_elements=max_elements, - num_quantiles=num_quantiles, - generate_quantiles=generate_quantiles) - is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( - self._quantile_accumulator_handle) - resources.register_resource(self._quantile_accumulator_handle, - self._create_op, is_initialized_op) - self._make_savable(name) - - def _make_savable(self, name): + def __init__(self, resource_handle, create_op, name): + self._resource_handle = resource_handle + self._create_op = create_op stamp_token, state, are_buckets_ready, buckets = ( - gen_quantile_ops.quantile_accumulator_serialize( - self._quantile_accumulator_handle)) + gen_quantile_ops.quantile_accumulator_serialize(resource_handle)) # slice_spec is useful for saving a slice from a variable. # It's not meaningful in quantile accumulator. 
slice_spec = "" @@ -96,9 +57,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): specs += [make_save_spec(state, "_state")] specs += [make_save_spec(are_buckets_ready, "_are_buckets_ready")] specs += [make_save_spec(buckets, "buckets")] - super(QuantileAccumulator, - self).__init__(self._quantile_accumulator_handle, specs, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle, + specs, name) def restore(self, restored_tensors, unused_restored_shapes): """Restores the associated quantile accumulator from 'restored_tensors'. @@ -119,24 +79,94 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): buckets = restored_tensors[3] with ops.control_dependencies([self._create_op]): return gen_quantile_ops.quantile_accumulator_deserialize( - self._quantile_accumulator_handle, + self._resource_handle, stamp_token=stamp_token, stream_state=state, are_buckets_ready=are_buckets_ready, buckets=buckets) + +class QuantileAccumulator(tracking.TrackableResource): + """A resource that allows distributed quantile computation.""" + + def __init__(self, + init_stamp_token, + epsilon, + num_quantiles, + max_elements=None, + name=None, + container=None, + generate_quantiles=False): + """Creates a QuantileAccumulator object. + + Args: + init_stamp_token: The initial value for the stamp token. + epsilon: Error bound on the quantile computation. + num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. + name: the name to save the accumulator under. + container: An optional `string`. Defaults to `""` + generate_quantiles: Generate quantiles instead of approximate boundaries. + If true, exactly `num_quantiles` will be produced in the final summary. 
+ """ + self._init_stamp_token = init_stamp_token + self._epsilon = epsilon + self._num_quantiles = num_quantiles + self._max_elements = max_elements + self._container = container + self._generate_quantiles = generate_quantiles + super(QuantileAccumulator, self).__init__() + + name = _PATTERN.sub("", name) + with ops.name_scope(name, "QuantileAccumulator") as name: + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self._init_op, + is_initialized_op) + self._saveable = QuantileAccumulatorSaveable(self.resource_handle, + self._init_op, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + def create_resource(self): + return gen_quantile_ops.quantile_stream_resource_handle_op( + container=self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_quantile_ops.create_quantile_accumulator( + self.resource_handle, + self._init_stamp_token, + epsilon=self._epsilon, + max_elements=self._max_elements, + num_quantiles=self._num_quantiles, + generate_quantiles=self._generate_quantiles) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_quantile_ops.quantile_accumulator_is_initialized( + self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return {"quantile_accumulator", self.saveable} + def get_buckets(self, stamp_token): """Returns quantile buckets created during previous flush.""" are_buckets_ready, buckets = ( gen_quantile_ops.quantile_accumulator_get_buckets( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token)) return are_buckets_ready[0], buckets[0] def schedule_get_buckets(self): """Returns a scheduled read of buckets created 
during previous flush.""" return batch_ops_utils.ScheduledStampedResourceOp( - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, op=gen_quantile_ops.quantile_accumulator_get_buckets) def _make_summary(self, column, example_weights): @@ -161,14 +191,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Adds quantile summary to its stream in resource.""" summary = self._make_summary(column, example_weights) return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) def add_prebuilt_summary(self, stamp_token, summary): """Adds quantile summary to its stream in resource.""" return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) @@ -177,7 +207,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): summary = self._make_summary(column, example_weights) return batch_ops_utils.ScheduledStampedResourceOp( op=gen_quantile_ops.quantile_accumulator_add_summaries, - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, summaries=summary) def flush(self, stamp_token, next_stamp_token): @@ -190,17 +220,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): The flush operation. 
""" return gen_quantile_ops.quantile_accumulator_flush( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) def flush_summary(self, stamp_token, next_stamp_token): """Finalizes quantile summary stream and resets it for next iteration.""" result = gen_quantile_ops.quantile_accumulator_flush_summary( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result - - def resource(self): - return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index 2e94e353f325f06eed2d290d3a7a461861820c39..ad1191d41236e71008bff8c8a7fbd42c16e3f9c5 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,12 +26,83 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") -class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): +class StatsAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for StatsAccumulator.""" + + def __init__(self, resource_handle, create_op, is_scalar, name): + self._create_op = create_op + self._resource_handle = resource_handle + self._is_scalar = is_scalar + slice_spec = "" + saver_name = self._resource_handle.name + (stamp_token, num_updates, partition_ids, feature_ids, gradients, + hessians) = self.serialize() + specs = [ + saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, + saver_name + "_stamp"), + saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, + saver_name + "_num_updates"), + saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, + saver_name + "_partition_ids"), + saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, + saver_name + "_feature_ids"), + saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, + saver_name + "_gradients"), + saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, + saver_name + "hessians"), + ] + super(StatsAccumulatorSaveable, self).__init__(self._resource_handle, specs, + name) + + def serialize(self): + """Serializes the stats accumulator state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( + self._resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( + self._resource_handle) + + def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, + gradients, hessians): + """Resets the stats accumulator with the serialized state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( + self._resource_handle, stamp_token, num_updates, 
partition_ids, + feature_ids, gradients, hessians) + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree ensemble from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree ensemble variable. + """ + with ops.control_dependencies([self._create_op]): + return self.deserialize( + stamp_token=restored_tensors[0], + num_updates=restored_tensors[1], + partition_ids=restored_tensors[2], + feature_ids=restored_tensors[3], + gradients=restored_tensors[4], + hessians=restored_tensors[5]) + + +class StatsAccumulator(tracking.TrackableResource): """A resource that allows to accumulate gradients and hessians. For consistency guarantees, we use read and write stamp tokens. @@ -58,58 +129,69 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): Returns: A `Tensor` of type mutable `string`. The handle to the stats accumulator. """ + self._stamp_token = stamp_token + self._gradient_shape = gradient_shape + self._hessian_shape = hessian_shape + self._container = container + + if (gradient_shape == tensor_shape.scalar() and + hessian_shape == tensor_shape.scalar()): + self._is_scalar = True + else: + self._is_scalar = False + if name is not None: name = _PATTERN.sub("", name) with ops.name_scope(name, "StatsAccumulator") as name: - # Both values are scalars. - if (gradient_shape == tensor_shape.scalar() and - hessian_shape == tensor_shape.scalar()): - self._is_scalar = True - self._resource_handle = (gen_stats_accumulator_ops. 
- stats_accumulator_scalar_resource_handle_op( - container, name, name=name)) - - create_op = gen_stats_accumulator_ops.create_stats_accumulator_scalar( - self._resource_handle, stamp_token) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( - self._resource_handle)) - else: - self._is_scalar = False - self._resource_handle = (gen_stats_accumulator_ops. - stats_accumulator_tensor_resource_handle_op( - container, name, name=name)) - create_op = gen_stats_accumulator_ops.create_stats_accumulator_tensor( - self._resource_handle, stamp_token, gradient_shape.as_list(), - hessian_shape.as_list()) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( - self._resource_handle)) + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self.initializer, + is_initialized_op) + self._saveable = StatsAccumulatorSaveable( + self.resource_handle, self.initializer, self._is_scalar, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - self._create_op = create_op - slice_spec = "" - saver_name = self._resource_handle.name - (stamp_token, num_updates, partition_ids, feature_ids, gradients, - hessians) = self.serialize() - specs = [ - saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, - saver_name + "_stamp"), - saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, - saver_name + "_num_updates"), - saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, - saver_name + "_partition_ids"), - saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, - saver_name + "_feature_ids"), - saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, - saver_name + "_gradients"), - saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, - saver_name + "hessians"), - ] + def create_resource(self): + if self._is_scalar: + return ( + 
gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op( + self._container, self._name, name=self._name)) + else: + return ( + gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op( + self._container, self._name, name=self._name)) - super(StatsAccumulator, self).__init__(self._resource_handle, specs, name) - resources.register_resource(self._resource_handle, create_op, - is_initialized_op) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + def initialize(self): + if self._is_scalar: + return gen_stats_accumulator_ops.create_stats_accumulator_scalar( + self.resource_handle, self._stamp_token) + else: + return gen_stats_accumulator_ops.create_stats_accumulator_tensor( + self.resource_handle, self._stamp_token, + self._gradient_shape.as_list(), self._hessian_shape.as_list()) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( + self.resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( + self.resource_handle) + + @property + def saveable(self): + return self._saveable + + def _gather_saveables_for_checkpoint(self): + return {"stats_accumulator", self.saveable} def add(self, stamp_token, partition_ids, feature_ids, gradients, hessians): """Updates the stats accumulator.""" @@ -117,11 +199,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): partition_ids, feature_ids, gradients, hessians)) if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_add( - [self._resource_handle], stamp_token, [partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_add( - [self._resource_handle], stamp_token, [partition_ids], 
[feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) def schedule_add(self, partition_ids, feature_ids, gradients, hessians): @@ -131,7 +213,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): if self._is_scalar: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_scalar_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -139,7 +221,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): else: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_tensor_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -153,55 +235,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): return gen_stats_accumulator_ops.stats_accumulator_tensor_make_summary( partition_ids, feature_ids, gradients, hessians) - def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, - gradients, hessians): - """Resets the stats accumulator with the serialized state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - def flush(self, stamp_token, next_stamp_token): """Flushes the stats accumulator.""" if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_flush( - self._resource_handle, stamp_token, next_stamp_token) + self.resource_handle, stamp_token, next_stamp_token) else: return 
gen_stats_accumulator_ops.stats_accumulator_tensor_flush( - self._resource_handle, stamp_token, next_stamp_token) - - def serialize(self): - """Serializes the stats accumulator state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( - self._resource_handle) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( - self._resource_handle) - - def restore(self, restored_tensors, unused_restored_shapes): - """Restores the associated tree ensemble from 'restored_tensors'. - - Args: - restored_tensors: the tensors that were loaded from a checkpoint. - unused_restored_shapes: the shapes this object should conform to after - restore. Not meaningful for trees. - - Returns: - The operation that restores the state of the tree ensemble variable. - """ - with ops.control_dependencies([self._create_op]): - return self.deserialize( - stamp_token=restored_tensors[0], - num_updates=restored_tensors[1], - partition_ids=restored_tensors[2], - feature_ids=restored_tensors[3], - gradients=restored_tensors[4], - hessians=restored_tensors[5]) - - def resource(self): - return self._resource_handle + self.resource_handle, stamp_token, next_stamp_token) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 1cf61a10ba25f206333bb78b7944e366bcd19b92..ab5713fbe26ab76eac923035e9feecc2ec51f492 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -992,7 +992,7 @@ class GradientBoostedDecisionTreeModel(object): # Get accumulated steps and examples for the current layer. 
_, _, _, _, acc_examples, acc_steps = ( - steps_accumulator.serialize()) + steps_accumulator.saveable.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) ensemble_update_ops.append( diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 5ecd4f341831ce8d6f8eb04a763280c177ffe275..40b1e667ee6039b44b1a442d41dc28dfcbad6dc6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -25,6 +25,13 @@ import six from tensorflow.python.training.server_lib import ClusterSpec +def _format_master_url(master, rpc_layer=None): + if rpc_layer: + return '%s://%s' % (rpc_layer, master) + else: + return master + + @six.add_metaclass(abc.ABCMeta) class ClusterResolver(object): """Abstract class for all implementations of ClusterResolvers. @@ -57,12 +64,13 @@ class ClusterResolver(object): 'cluster_spec is not implemented for {}.'.format(self)) @abc.abstractmethod - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Retrieves the name or URL of the session master. Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. Returns: The name or URL of the session master. 
@@ -77,10 +85,18 @@ class ClusterResolver(object): class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec, master=''): + def __init__(self, cluster_spec, master='', task_type=None, task_index=None, + environment='', num_accelerators_per_worker=0, + rpc_layer=None): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() + self._task_type = task_type + self._task_index = task_index + self._environment = environment + self._num_accelerators_per_worker = num_accelerators_per_worker + self._rpc_layer = rpc_layer + if not isinstance(cluster_spec, ClusterSpec): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec @@ -93,12 +109,13 @@ class SimpleClusterResolver(ClusterResolver): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a session. Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC used by distributed TensorFlow. Returns: The name or URL of the session master. @@ -106,10 +123,52 @@ class SimpleClusterResolver(ClusterResolver): If a task_type and task_index is given, this will override the `master` string passed into the initialization function. 
""" - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + else: + master = self._master + + return _format_master_url(master, rpc_layer or self._rpc_layer) + + @property + def task_type(self): + return self._task_type + + @property + def task_index(self): + return self._task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._environment + + def num_accelerators_per_worker(self, session_config=None): + """Returns the number of accelerator cores per worker. + + Args: + session_config: Unused. The SimpleClusterResolver does not do automatic + detection of accelerators, so a TensorFlow session will never be + created, and thus a `session_config` is never necessary here, and will + be ignored. + """ + del session_config + return self._num_accelerators_per_worker - return self._master + @property + def rpc_layer(self): + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer class UnionClusterResolver(ClusterResolver): @@ -119,13 +178,22 @@ class UnionClusterResolver(ClusterResolver): merges the underlying ClusterResolvers, and returns one unified ClusterSpec when cluster_spec is called. The details of the merge function is documented in the cluster_spec function. + + For additional Cluster Resolver properties such as task type, task index, + rpc layer, environment, etc..., we will return the value from the first + ClusterResolver in the union. """ - def __init__(self, *args): + def __init__(self, *args, **kwargs): """Initializes a UnionClusterResolver with other ClusterResolvers. Args: *args: `ClusterResolver` objects to be unionized. 
+ **kwargs: + rpc_layer - (Optional) Override value for the RPC layer used by + TensorFlow. + task_type - (Optional) Override value for the current task type. + task_index - (Optional) Override value for the current task index. Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. @@ -133,6 +201,13 @@ class UnionClusterResolver(ClusterResolver): """ super(UnionClusterResolver, self).__init__() + self._rpc_layer = kwargs.pop('rpc_layer', None) + self._task_type = kwargs.pop('task_type', None) + self._task_index = kwargs.pop('task_index', None) + + if kwargs: + raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs)) + if not args: raise ValueError('At least one ClusterResolver is required.') @@ -216,7 +291,7 @@ class UnionClusterResolver(ClusterResolver): return ClusterSpec(merged_cluster) - def master(self, task_type=None, task_index=None): + def master(self, task_type=None, task_index=None, rpc_layer=None): """Returns the master address to use when creating a session. This usually returns the master from the first ClusterResolver passed in, @@ -225,11 +300,45 @@ class UnionClusterResolver(ClusterResolver): Args: task_type: (Optional) The type of the TensorFlow task of the master. task_index: (Optional) The index of the TensorFlow task of the master. + rpc_layer: (Optional) The RPC protocol for the given cluster. Returns: The name or URL of the session master. 
""" - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + return _format_master_url(master, rpc_layer or self._rpc_layer) + + return self._cluster_resolvers[0].master(rpc_layer=rpc_layer) + + @property + def task_type(self): + return self._task_type or self._cluster_resolvers[0].task_type + + @property + def task_index(self): + return self._task_index or self._cluster_resolvers[0].task_index + + @task_type.setter + def task_type(self, task_type): + self._task_type = task_type + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + return self._cluster_resolvers[0].environment + + def num_accelerators_per_worker(self, session_config=None): + return self._cluster_resolvers[0].num_accelerators_per_worker( + session_config) + + @property + def rpc_layer(self): + return self._rpc_layer or self._cluster_resolvers[0].rpc_layer - return self._cluster_resolvers[0].master() + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index c004b2e2d3bc6552a3ab10997ed44f24e611735a..b94c9612b5bd4d92e84319f22932ce5599ba4b36 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -57,6 +57,62 @@ class UnionClusterResolverTest(test.TestCase): actual_cluster_spec = union_resolver.cluster_spec() self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testInitSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", 
"worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + self.assertEqual(simple_resolver.task_type, "ps") + self.assertEqual(simple_resolver.task_index, 1) + self.assertEqual(simple_resolver.environment, "cloud") + self.assertEqual(simple_resolver.num_accelerators_per_worker(), 8) + self.assertEqual(simple_resolver.rpc_layer, "grpc") + + def testOverrideSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + simple_resolver.task_type = "worker" + simple_resolver.task_index = 2 + simple_resolver.rpc_layer = "http" + + self.assertEqual(simple_resolver.task_type, "worker") + self.assertEqual(simple_resolver.task_index, 2) + self.assertEqual(simple_resolver.rpc_layer, "http") + + def testSimpleOverrideMasterWithTaskIndexZero(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec) + actual_master = simple_resolver.master("worker", 0, rpc_layer="grpc") + self.assertEqual(actual_master, "grpc://worker0:2222") + + def testSimpleOverrideMasterWithRpcLayer(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + + simple_resolver = SimpleClusterResolver(base_cluster_spec) + actual_master = simple_resolver.master("worker", 2, rpc_layer="grpc") + self.assertEqual(actual_master, "grpc://worker2:2222") + def testSimpleOverrideMaster(self): base_cluster_spec = server_lib.ClusterSpec({ "ps": 
["ps0:2222", "ps1:2222"], @@ -65,7 +121,42 @@ class UnionClusterResolverTest(test.TestCase): simple_resolver = SimpleClusterResolver(base_cluster_spec) actual_master = simple_resolver.master("worker", 2) - self.assertEquals(actual_master, "worker2:2222") + self.assertEqual(actual_master, "worker2:2222") + + def testUnionClusterResolverGetProperties(self): + cluster_spec_1 = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + resolver1 = SimpleClusterResolver(cluster_spec_1, task_type="ps", + task_index=1, environment="cloud", + num_accelerators_per_worker=8, + rpc_layer="grpc") + + cluster_spec_2 = server_lib.ClusterSpec({ + "ps": ["ps2:2222", "ps3:2222"], + "worker": ["worker3:2222", "worker4:2222", "worker5:2222"] + }) + resolver2 = SimpleClusterResolver(cluster_spec_2, task_type="worker", + task_index=2, environment="local", + num_accelerators_per_worker=16, + rpc_layer="http") + + union_resolver = UnionClusterResolver(resolver1, resolver2) + + self.assertEqual(union_resolver.task_type, "ps") + self.assertEqual(union_resolver.task_index, 1) + self.assertEqual(union_resolver.environment, "cloud") + self.assertEqual(union_resolver.num_accelerators_per_worker(), 8) + self.assertEqual(union_resolver.rpc_layer, "grpc") + + union_resolver.task_type = "worker" + union_resolver.task_index = 2 + union_resolver.rpc_layer = "http" + + self.assertEqual(union_resolver.task_type, "worker") + self.assertEqual(union_resolver.task_index, 2) + self.assertEqual(union_resolver.rpc_layer, "http") def testTwoNonOverlappingJobMergedClusterResolver(self): cluster_spec_1 = server_lib.ClusterSpec({ @@ -116,10 +207,13 @@ class UnionClusterResolverTest(test.TestCase): union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) unspecified_master = union_cluster.master() - self.assertEquals(unspecified_master, "") + self.assertEqual(unspecified_master, "") specified_master = 
union_cluster.master("worker", 1) - self.assertEquals(specified_master, "worker1:2222") + self.assertEqual(specified_master, "worker1:2222") + + rpc_master = union_cluster.master("worker", 1, rpc_layer="grpc") + self.assertEqual(rpc_master, "grpc://worker1:2222") def testOverlappingJobMergedClusterResolver(self): cluster_spec_1 = server_lib.ClusterSpec({ diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 5083e4d10ba6ee2e1be8f373c099556b422ef5aa..195b68959b6d21ef674438a4a23a4dd07f45faa7 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -30,6 +30,10 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +def _format_master_url(master, rpc_layer=None): + return '%s://%s' % (rpc_layer, master) if rpc_layer else master + + class GceClusterResolver(ClusterResolver): """Cluster Resolver for Google Compute Engine. @@ -45,7 +49,10 @@ class GceClusterResolver(ClusterResolver): zone, instance_group, port, - job_name='worker', + task_type='worker', + task_index=0, + rpc_layer='grpc', + num_accelerators_per_worker=0, credentials='default', service=None): """Creates a new GceClusterResolver object. @@ -55,13 +62,22 @@ class GceClusterResolver(ClusterResolver): each instance in the instance group. Args: - project: Name of the GCE project - zone: Zone of the GCE instance group - instance_group: Name of the GCE instance group + project: Name of the GCE project. + zone: Zone of the GCE instance group. + instance_group: Name of the GCE instance group. port: Port of the listening TensorFlow server (default: 8470) - job_name: Name of the TensorFlow job this set of instances belongs to + task_type: Name of the TensorFlow job this GCE instance group of VM + instances belongs to. 
+ task_index: The task index for this particular VM, within the GCE + instance group. In particular, every single instance should be assigned + a unique ordinal index within an instance group manually so that they + can be distinguished from each other. + rpc_layer: The RPC layer TensorFlow should use to communicate across + instances. + num_accelerators_per_worker: Number of accelerators (GPUs) present per + instance. credentials: GCE Credentials. If nothing is specified, this defaults to - GoogleCredentials.get_application_default() + GoogleCredentials.get_application_default(). service: The GCE API object returned by the googleapiclient.discovery function. (Default: discovery.build('compute', 'v1')). If you specify a custom service object, then the credentials parameter will be ignored. @@ -72,7 +88,11 @@ class GceClusterResolver(ClusterResolver): self._project = project self._zone = zone self._instance_group = instance_group - self._job_name = job_name + self._task_type = task_type + self._task_index = task_index + self._rpc_layer = rpc_layer + # Read back by num_accelerators_per_worker(); must be set here. + self._num_accelerators_per_worker = num_accelerators_per_worker self._port = port self._credentials = credentials @@ -133,10 +153,55 @@ class GceClusterResolver(ClusterResolver): previous_response=response) worker_list.sort() - return ClusterSpec({self._job_name: worker_list}) + return ClusterSpec({self._task_type: worker_list}) - def master(self, task_type=None, task_index=None): - if task_type and task_index: - return self.cluster_spec().task_address(task_type, task_index) + def master(self, task_type=None, task_index=None, rpc_layer=None): + task_type = task_type if task_type is not None else self._task_type + task_index = task_index if task_index is not None else self._task_index + + if task_type is not None and task_index is not None: + master = self.cluster_spec().task_address(task_type, task_index) + return _format_master_url(master, rpc_layer or self._rpc_layer) return '' + + @property + def task_type(self): + return 
self._task_type + + @property + def task_index(self): + return self._task_index + + @task_type.setter + def task_type(self, task_type): + raise RuntimeError( + 'You cannot reset the task_type of the GceClusterResolver after it has ' + 'been created.') + + @task_index.setter + def task_index(self, task_index): + self._task_index = task_index + + @property + def environment(self): + """Returns the current environment which TensorFlow is running in. + + For users in the GCE environment, the environment property is always an + empty string, and Google users will not use this ClusterResolver for running + on internal systems. + """ + return '' + + @property + def rpc_layer(self): + return self._rpc_layer + + @rpc_layer.setter + def rpc_layer(self, rpc_layer): + self._rpc_layer = rpc_layer + + def num_accelerators_per_worker(self, session_config=None): + del session_config # Unused, since this is set manually in __init__. + return self._num_accelerators_per_worker + diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py index 87b8303122498992dd24ae06824f7f769357d8f8..c691552e86025896e23891a3e8f7da5ed2f9da31 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver_test.py @@ -135,12 +135,86 @@ class GceClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testMasterRetrieval(self): + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_index=0, + port=8470, + credentials=None, + service=self.standard_mock_service_client()) + self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470') + + def testMasterRetrievalWithCustomTasks(self): + name_to_ip = [ + {'name': 'instance1', 'ip': 
'10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + self.assertEqual( + gce_cluster_resolver.master('worker', 2, 'test'), + 'test://10.3.4.5:8470') + + def testOverrideParameters(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_type='testworker', + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + gce_cluster_resolver.task_index = 1 + gce_cluster_resolver.rpc_layer = 'test' + + self.assertEqual(gce_cluster_resolver.task_type, 'testworker') + self.assertEqual(gce_cluster_resolver.task_index, 1) + self.assertEqual(gce_cluster_resolver.rpc_layer, 'test') + self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470') + + def testOverrideParametersWithZeroOrEmpty(self): + name_to_ip = [ + {'name': 'instance1', 'ip': '10.1.2.3'}, + {'name': 'instance2', 'ip': '10.2.3.4'}, + {'name': 'instance3', 'ip': '10.3.4.5'}, + ] + + gce_cluster_resolver = GceClusterResolver( + project='test-project', + zone='us-east1-d', + instance_group='test-instance-group', + task_type='', + task_index=1, + port=8470, + credentials=None, + service=self.gen_standard_mock_service_client(name_to_ip)) + + self.assertEqual(gce_cluster_resolver.master( + task_type='', task_index=0), 'grpc://10.1.2.3:8470') + def testCustomJobNameAndPortRetrieval(self): gce_cluster_resolver = GceClusterResolver( project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='custom', + task_type='custom', 
port=2222, credentials=None, service=self.standard_mock_service_client()) @@ -196,7 +270,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='worker', + task_type='worker', port=8470, credentials=None, service=self.gen_standard_mock_service_client(worker1_name_to_ip)) @@ -205,7 +279,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='worker', + task_type='worker', port=8470, credentials=None, service=self.gen_standard_mock_service_client(worker2_name_to_ip)) @@ -214,7 +288,7 @@ class GceClusterResolverTest(test.TestCase): project='test-project', zone='us-east1-d', instance_group='test-instance-group', - job_name='ps', + task_type='ps', port=2222, credentials=None, service=self.gen_standard_mock_service_client(ps_name_to_ip)) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index c4ac9d0700194da558820aabc28bf1c0857591e2..1f6803a9ff9a7a1e72ee691afd7e22bb4d85475c 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -50,6 +50,34 @@ class TPUClusterResolver(ClusterResolver): Cloud Platform project. """ + def _tpuService(self): + """Creates a new Cloud TPU API object. + + This works around an issue where the underlying HTTP connection sometimes + times out when the script has been running for too long. Other methods in + this object call this method to get a new API object whenever they need + to communicate with the Cloud API. + + Returns: + A Google Cloud TPU API object. 
+ """ + if self._service: + return self._service + + credentials = self._credentials + if credentials is None or credentials == 'default': + credentials = GoogleCredentials.get_application_default() + + if self._discovery_url: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials, + discoveryServiceUrl=self._discovery_url) + else: + return discovery.build( + 'tpu', 'v1alpha1', + credentials=credentials) + def _requestComputeMetadata(self, path): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) @@ -81,7 +109,7 @@ class TPUClusterResolver(ClusterResolver): return None @staticmethod - def _discoveryUrl(): + def _environmentDiscoveryUrl(): return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) def __init__(self, @@ -154,49 +182,42 @@ class TPUClusterResolver(ClusterResolver): self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name - self._credentials = credentials + # Whether we should actually attempt to contact Cloud APIs should_resolve = self._shouldResolve() + # We error out if we are in a non-Cloud environment which cannot talk to the + # Cloud APIs using the standard class and a special object is not passed in. + self._service = service + if (self._service is None and should_resolve and + not _GOOGLE_API_CLIENT_INSTALLED): + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + # We save user-passed credentials, unless the user didn't pass in anything. + self._credentials = credentials + if (credentials == 'default' and should_resolve and + _GOOGLE_API_CLIENT_INSTALLED): + self._credentials = None + + # Automatically detect project and zone if unspecified. 
if not project and should_resolve: project = compat.as_str( self._requestComputeMetadata('project/project-id')) - if not zone and should_resolve: zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] - self._project = project self._zone = zone - if credentials == 'default' and should_resolve: - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None and should_resolve: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - final_discovery_url = self._discoveryUrl() or discovery_url - if final_discovery_url: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials, - discoveryServiceUrl=final_discovery_url) - else: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) - else: - self._service = service + self._discovery_url = self._environmentDiscoveryUrl() or discovery_url self._coordinator_name = coordinator_name - if coordinator_name and not coordinator_address and (should_resolve or - in_gke): + if (coordinator_name and not coordinator_address and + (should_resolve or in_gke)): self._start_local_server() else: self._coordinator_address = coordinator_address @@ -270,7 +291,8 @@ class TPUClusterResolver(ClusterResolver): # Case 1. 
full_name = 'projects/%s/locations/%s/nodes/%s' % ( self._project, self._zone, compat.as_text(self._tpu)) - request = self._service.projects().locations().nodes().get(name=full_name) + service = self._tpuService() + request = service.projects().locations().nodes().get(name=full_name) response = request.execute() if 'state' in response and response['state'] != 'READY': diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index ad4f6432630be44a7de6e778f55f1fb7fd66f307..478c82967ba993c0551113a38879f87d872517a3 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -459,10 +459,10 @@ class TPUClusterResolverTest(test.TestCase): del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] - def testDiscoveryUrl(self): + def testEnvironmentDiscoveryUrl(self): os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' self.assertEqual('https://{api}.internal/{apiVersion}', - TPUClusterResolver._discoveryUrl()) + TPUClusterResolver._environmentDiscoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index fbdca497fcc3126d2086d289ebdb113370072d22..a63366e1361effe20787c197eddd66b5c0c96410 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -59,8 +59,6 @@ option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires M # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) -set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") -set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) @@ -72,25 +70,25 @@ 
endif() if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option # for targets that link ${CMAKE_THREAD_LIBS_INIT}. - find_package (Threads) + find_package (Threads REQUIRED) # Options for linking CUDA/CUDNN libraries - option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) + option(tensorflow_PATH_CUDA_LIB "Additional library search path for cudnn, nccl, culibos" /usr/local/cuda/lib64/) option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) if (NOT tensorflow_CUDNN_INCLUDE) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) - option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + option(tensorflow_PATH_CUDNN_LIB "Override PATH_CUDA_LIB for cudnn" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_CUDNN_LIB) # option's default value is OFF. Fill it with real default values - set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) - option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_CUDNN_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_CUDNN_LIB) + option(tensorflow_PATH_NCCL_LIB "Override PATH_CUDA_LIB for nccl" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_NCCL_LIB) # option's default value is OFF. 
Fill it with real default values - set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_NCCL_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_NCCL_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. Fill it with real default values @@ -210,14 +208,17 @@ endif() include(CheckCXXCompilerFlag) # OpenMP Support -CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) -if (GCC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") -endif() -CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) -if (MSVC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") -endif() +if (WIN32) + CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) + if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") + endif() +else (WIN32) + CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) + if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + endif() +endif (WIN32) # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) @@ -377,29 +378,19 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - # later command will make use of the value in tensorflow_CUDA_VERSION - find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED EXACT) - - # Test compatibility of compiler on CUDA - try_compile(CUDA_TEST_COMPILE_C - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.c - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - try_compile(CUDA_TEST_COMPILE_CXX - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.cc - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - if(NOT (CUDA_TEST_COMPILE_C AND CUDA_TEST_COMPILE_CXX)) - 
message(FATAL_ERROR "Selected compiler (or version) is not supported for CUDA") + # minimum 9.1 in cuda version + find_package(CUDA 9.1 REQUIRED) + if(NOT CUDA_FOUND) + message(FATAL_ERROR "CUDA not found.") endif() - # by default we assume compute cabability 3.5 and 5.2. If you change this change it in - # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_37,code=\"sm_37,compute_37\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_52,code=\"sm_52,compute_52\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_60,code=\"sm_60,compute_60\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_61,code=\"sm_61,compute_61\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_70,code=\"sm_70,compute_70\") + # use cmake internal CUDA_ARCH_NAME switch + # e.g. CUDA_ARCH_NAME="Auto" will autodetect + # CUDA_ARCH_NAME="All" will use all arches + cuda_select_nvcc_arch_flags(NVCC_ARCH_FLAGS ${CUDA_ARCH_NAME}) + list(APPEND CUDA_NVCC_FLAGS ${NVCC_ARCH_FLAGS}) + message(STATUS "Using CUDA arch flags: ${NVCC_ARCH_FLAGS_readable}") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) @@ -423,43 +414,94 @@ if (tensorflow_ENABLE_GPU) else (WIN32) set(CUDNN_INCLUDE "${tensorflow_CUDNN_INCLUDE}") - find_library(nccl_STATIC_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT nccl_STATIC_LIBRARY) + if (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl.so PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif 
(tensorflow_BUILD_SHARED_LIB) + if (NOT nccl_LIBRARY) message(FATAL_ERROR "NCCL is required for GPU-build") - else (NOT nccl_STATIC_LIBRARY) - message("nccl-static: ${nccl_STATIC_LIBRARY}") + else (NOT nccl_LIBRARY) + message("nccl: ${nccl_LIBRARY}") # something like /usr/lib64/libnccl_static.a - endif (NOT nccl_STATIC_LIBRARY) - - find_library(cudnn_STATIC_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT cudnn_STATIC_LIBRARY) + endif (NOT nccl_LIBRARY) + + if (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn.so PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT cudnn_LIBRARY) message(FATAL_ERROR "CUDNN is required for GPU-build") - else (NOT cudnn_STATIC_LIBRARY) - message("cudnn-static: ${cudnn_STATIC_LIBRARY}") - endif (NOT cudnn_STATIC_LIBRARY) - - find_library(culibos_STATIC_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT culibos_STATIC_LIBRARY) + else (NOT cudnn_LIBRARY) + file(READ ${CUDNN_INCLUDE}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) + # fetch cudnn version + string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}") + string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}") + 
if(NOT CUDNN_VERSION_MAJOR) + set(CUDNN_VERSION "???") + else() + set(CUDNN_VERSION "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}") + endif() + message(STATUS "cudnn library: ${cudnn_LIBRARY} (found version: \"${CUDNN_VERSION}\")") + endif (NOT cudnn_LIBRARY) + + if (tensorflow_BUILD_SHARED_LIB) + # shared first (if exists) else static one + find_library(culibos_LIBRARY NAMES libculibos.so libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + # only static version + find_library(culibos_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT culibos_LIBRARY) message(FATAL_ERROR "CULIBOS is required for GPU-build") - else (NOT culibos_STATIC_LIBRARY) - message("culibos-static: ${culibos_STATIC_LIBRARY}") - endif (NOT culibos_STATIC_LIBRARY) + else (NOT culibos_LIBRARY) + message("culibos: ${culibos_LIBRARY}") + endif (NOT culibos_LIBRARY) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} - ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) + ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_LIBRARY} ${culibos_LIBRARY} ${nccl_LIBRARY}) endif (WIN32) include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. - string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + string(REPLACE "." 
"" short_CUDA_VER ${CUDA_VERSION}) + + # List of enumerated CUDA caps + string(REPLACE " " ";" NVCC_ARCH_LIST "${NVCC_ARCH_FLAGS_readable}") + set(list ${NVCC_ARCH_LIST}) + + # Construct capability string + foreach(NVCC_ARCH ${NVCC_ARCH_LIST}) + if (NVCC_ARCH MATCHES "sm_") + string(REGEX REPLACE "^.sm*" "" NVCC_ARCH ${NVCC_ARCH}) + math(EXPR NVCC_ARCH_MAJOR "${NVCC_ARCH} / 10") + math(EXPR NVCC_ARCH_MINOR "(${NVCC_ARCH} - (${NVCC_ARCH_MAJOR}*10))") + if (TF_CUDA_CAP) + set(TF_CUDA_CAP "${TF_CUDA_CAP},CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + else (TF_CUDA_CAP) + set(TF_CUDA_CAP "CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + endif (TF_CUDA_CAP) + endif() + endforeach() # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.7\"),CudaVersion(\"5.2\"),CudaVersion(\"6.0\"),CudaVersion(\"6.1\"),CudaVersion(\"7.0\")\n" + "#define TF_CUDA_CAPABILITIES ${TF_CUDA_CAP}\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" - "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" + "#define TF_CUDNN_VERSION \"64_${CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -494,14 +536,14 @@ if (tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll cudart_dll_name=cudart64_${short_CUDA_VER}.dll - cuda_version_number=${tensorflow_CUDA_VERSION} + cuda_version_number=${CUDA_VERSION} nvcuda_dll_name=nvcuda.dll cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=${tensorflow_CUDA_VERSION} - cudnn_version_number=${tensorflow_CUDNN_VERSION}) + cuda_version_number=${CUDA_VERSION} + 
cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake index c6c5021f60b38ed05a19f3e439c9810251841f76..4546dbdecc0dbc36f17cc727345e0762718b5165 100644 --- a/tensorflow/contrib/cmake/external/abseil_cpp.cmake +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -20,6 +20,7 @@ if (systemlib_ABSEIL_CPP) absl_dynamic_annotations absl_malloc_internal absl_throw_delegate + absl_int128 absl_strings str_format_internal absl_bad_optional_access) @@ -50,6 +51,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/Release/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/Release/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) @@ -60,6 +62,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/absl_dynamic_annotations.lib ${abseil_cpp_BUILD}/absl/base/absl_malloc_internal.lib ${abseil_cpp_BUILD}/absl/base/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib ${abseil_cpp_BUILD}/absl/strings/absl_strings.lib ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) @@ -71,6 +74,7 @@ else (systemlib_ABSEIL_CPP) ${abseil_cpp_BUILD}/absl/base/libabsl_dynamic_annotations.a ${abseil_cpp_BUILD}/absl/base/libabsl_malloc_internal.a ${abseil_cpp_BUILD}/absl/base/libabsl_throw_delegate.a + ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a 
${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index d94b703700cfcd9ecae7f1d2718ba33ffd82c176..96160568fa79291a7b391761373e1eaf0f70974e 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -57,6 +57,7 @@ tensorflow/python/ops tensorflow/python/ops/distributions tensorflow/python/ops/linalg tensorflow/python/ops/losses +tensorflow/python/ops/signal tensorflow/python/platform tensorflow/python/profiler tensorflow/python/profiler/internal @@ -377,8 +378,6 @@ tensorflow/contrib/seq2seq/python/ops tensorflow/contrib/session_bundle tensorflow/contrib/session_bundle/example tensorflow/contrib/signal -tensorflow/contrib/signal/python -tensorflow/contrib/signal/python/ops tensorflow/contrib/slim tensorflow/contrib/slim/python tensorflow/contrib/slim/python/slim diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 88b4a6165c0f9171ec7cc169bc099c7db1549ee7..d66e39ac07c7b7c9423fa7e878a9cefd94b867bd 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,14 +68,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - 
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index ef337b3a15c4f58fe183af78d34376b3ed27099a..9cfa8b90749280b6aa815cc210941c75bd5e16c5 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -89,7 +89,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index ef487d3509bf3c9bfaf0b117998e6b121543c1c6..df7b854afcca1a0bed660624152f465d4bf3b25f 100755 --- 
a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -373,8 +373,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index 8d13dc7316a693657f1b6e102830808d35372fe9..3b49755afcf0753d31c0ce506dce42709b1ee8bc 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -49,7 +48,7 @@ class XLACompileContextTest(test.TestCase): histogram_summary = summary.histogram('histogram_summary', dummy_tensor) image_summary = summary.image('image_summary', dummy_tensor) scalar_summary = summary.scalar('scalar_summary', dummy_tensor) - tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor) + tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) summary.merge( [ audio_summary, histogram_summary, image_summary, scalar_summary, diff --git 
a/tensorflow/contrib/copy_graph/python/__init__.py b/tensorflow/contrib/copy_graph/python/__init__.py index b9ff28eb0d7115ff5919c2f758f70ba388f5d4d2..5c1048e02a3104c958f7710ba97980d3353adbad 100644 --- a/tensorflow/contrib/copy_graph/python/__init__.py +++ b/tensorflow/contrib/copy_graph/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/copy_graph/python/util/__init__.py b/tensorflow/contrib/copy_graph/python/util/__init__.py index b9ff28eb0d7115ff5919c2f758f70ba388f5d4d2..5c1048e02a3104c958f7710ba97980d3353adbad 100644 --- a/tensorflow/contrib/copy_graph/python/util/__init__.py +++ b/tensorflow/contrib/copy_graph/python/util/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index 57ffaa87e45559a6ecf4c8059e5a6cdee8b8b664..670b54943277806c47bfd6c6bc9b345db4bb1448 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -63,8 +63,8 @@ cuda_py_test( ], shard_count = 6, tags = [ - "no_oss", # b/117989214 "noasan", # http://b/62067814 + "requires-gpu-sm35", ], ) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 5d8c6191f8db9f96532aa78e4790a4665d3b4877..5320232268657fa73bcd3e86da49d6525e9b8db5 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -24,6 +24,10 @@ @@CudnnGRUSaveable @@CudnnRNNReluSaveable @@CudnnRNNTanhSaveable +@@CudnnParamsFormatConverterLSTM +@@CudnnParamsFormatConverterGRU +@@CudnnParamsFormatConverterTanh +@@CudnnParamsFormatConverterRelu """ from __future__ import absolute_import @@ -48,6 +52,10 @@ _allowed_symbols = [ "CudnnGRUSaveable", "CudnnRNNReluSaveable", "CudnnRNNTanhSaveable", + "CudnnParamsFormatConverterLSTM", + "CudnnParamsFormatConverterGRU", + "CudnnParamsFormatConverterTanh", + "CudnnParamsFormatConverterRelu", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index c59d3682d404e032d9f4bf81ef54ab456341cefa..ae839108ebec31b70b687e5ff3e99c7d5a9b560e 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -202,12 +202,13 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): dtype=dtype) random_seed.set_random_seed(1234) params_size_t = model.params_size() - params = variables.Variable( + params = variables.VariableV1( random_ops.random_uniform([params_size_t], dtype=dtype), dtype=dtype, validate_shape=False) saveable = 
_CreateParamsSavable(params, model) - weights, biases = saveable._OpaqueParamsToCanonical() + weights, biases = saveable.format_converter._opaque_to_cu_canonical( + saveable._variables) reset_params = state_ops.assign( params, array_ops.zeros([params_size_t], dtype=dtype), @@ -248,7 +249,7 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): params_size_t = model.params_size() names = ["rnn_1", "rnn_2"] param_vars = [ - variables.Variable( + variables.VariableV1( random_ops.random_uniform([params_size_t], dtype=dtype), dtype=dtype, validate_shape=False) for name in names @@ -256,8 +257,10 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): saveables = [] for name, params in zip(names, param_vars): saveables.append(_CreateParamsSavable(params, model, name, name)) - weights1, biases1 = saveables[0]._OpaqueParamsToCanonical() - weights2, biases2 = saveables[1]._OpaqueParamsToCanonical() + weights1, biases1 = saveables[0].format_converter._opaque_to_cu_canonical( + saveables[0]._variables) + weights2, biases2 = saveables[1].format_converter._opaque_to_cu_canonical( + saveables[1]._variables) reset_params = [ state_ops.assign( params, @@ -304,7 +307,7 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): direction=direction, dtype=dtype) params_size_t = model.params_size() - params = variables.Variable( + params = variables.VariableV1( array_ops.ones([params_size_t], dtype=dtype), validate_shape=False, dtype=dtype) @@ -422,21 +425,21 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): cudnn_rnn_ops.CUDNN_LSTM, constant_op.constant([4]), 200, 200, direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + _ = model.params_size() with self.assertRaisesRegexp( ValueError, "Shape must be rank 0 but is rank 1"): model = _CreateModel( cudnn_rnn_ops.CUDNN_LSTM, 4, constant_op.constant([200]), 200, direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + _ = model.params_size() with self.assertRaisesRegexp( ValueError, 
"Shape must be rank 0 but is rank 1"): model = _CreateModel( cudnn_rnn_ops.CUDNN_LSTM, 4, 200, constant_op.constant([200]), direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + _ = model.params_size() class CudnnRNNTestInference(TensorFlowTestCase): @@ -458,7 +461,7 @@ class CudnnRNNTestInference(TensorFlowTestCase): params_size_t = model.params_size() input_data = array_ops.ones([seq_length, batch_size, input_size]) input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - params = variables.Variable( + params = variables.VariableV1( array_ops.ones([params_size_t]), validate_shape=False) if has_input_c: input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units]) @@ -584,20 +587,20 @@ class CudnnRNNTestTraining(TensorFlowTestCase): dtype=dtype, dropout=dropout) params_size_t = model.params_size() - input_data = variables.Variable( + input_data = variables.VariableV1( random_ops.random_uniform( [seq_length, batch_size, input_size], dtype=dtype), dtype=dtype) - input_h = variables.Variable( + input_h = variables.VariableV1( random_ops.random_uniform( [num_layers * dir_count, batch_size, num_units], dtype=dtype), dtype=dtype) - params = variables.Variable( + params = variables.VariableV1( random_ops.random_uniform([params_size_t], dtype=dtype), validate_shape=False, dtype=dtype) if has_input_c: - input_c = variables.Variable( + input_c = variables.VariableV1( random_ops.random_uniform( [num_layers * dir_count, batch_size, num_units], dtype=dtype), dtype=dtype) @@ -639,7 +642,8 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTraining(self): + def DISABLED_testSimpleTraining(self): + # TODO(jamesqin): fix b/117989214 test_configs = [ { "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py 
index f09466b631f69d6234573dd5eafada650421c117..60229af374be869005139921483793156e5e7a05 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -27,5 +27,10 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibl from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterGRU +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterLSTM +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterRelu +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterTanh from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index a324c6e7d76223aaa6514e695e4ff8444db455d0..8bbcc7cd0397a5339a69e4e44528f0e56584043a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -388,11 +388,11 @@ class _CudnnRNN(base_layer.Layer): output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. Raises: - ValueError: initial_state is not a tuple. + TypeError: initial_state is not a tuple. """ if initial_state is not None and not isinstance(initial_state, tuple): - raise ValueError("Invalid initial_state type: %s, expecting tuple.", - type(initial_state)) + raise TypeError("Invalid initial_state type: %s, expecting tuple." 
% + initial_state) dtype = self.dtype inputs = ops.convert_to_tensor(inputs, dtype=dtype) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 2c92f31788378c2a9f01183bc04b035668b59b59..d06d0c6bdaa113089c4d4239a6d4ed216ddd01a8 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -74,7 +74,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): - """Cudnn Compatible GRUCell. + r"""Cudnn Compatible GRUCell. A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by @@ -177,172 +177,60 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): return new_h, new_h -# TODO(yaozhang): make sure we only save the canonical version of params and -# don't save the platform-specific version to avoid potential race -# conditions where params is updated by both versions when being restored. -# Currently, checkpointing will function properly, despite that we save both -# versions, because Saver restores customized savables after Variables. -# However, it is good to not rely on this restoring order of Saver and to -# avoid unnecessary storage. Add a test to check only the canonical version is -# saved. -class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): - """Abstract SaveableObject implementation handling Cudnn opaque params.""" +class CudnnParamsFormatConverter(object): + """Abstract class that converts between params of Cudnn Rnn and TF Rnn.""" def __init__(self, - opaque_params, num_layers, num_units, input_size, input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - scope=None, - name="cudnn_rnn_saveable"): - """Creates a CudnnOpaqueParamsSaveable object. 
- - CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file - and is used to save/restore the weights and biases parameters in a - canonical format which is directly consumable by platform-independent tf - RNN cells. Parameters are saved as tensors layer by layer with weight - tensors followed by bias tensors, and forward direction followed by - backward direction (if applicable). When restoring, a user could name - param_variables as desired, and restore weight and bias tensors to these - variables. - - For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per - bias for each layer: tensor 0 is applied to the input from the previous - layer and tensor 1 to the recurrent input. - - For CudnnLSTM, there are 8 tensors per weight and per bias for each - layer: tensor 0-3 are applied to the input from the previous layer and - tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; - tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; - tensor 3 and 7 the output gate. - - For CudnnGRU, there are 6 tensors per weight and per bias for each layer: - tensor 0-2 are applied to the input from the previous layer and - tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; - tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + direction=CUDNN_RNN_UNIDIRECTION): + """Constructor. Args: - opaque_params: a variable, Cudnn RNN opaque params. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). 
- 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It could be one + of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input' + (default) always applies a linear projection of input onto RNN hidden + state. (standard RNN behavior). * 'skip_input' is only allowed when + input_size == num_units; * 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Could be either - 'unidirectional' or 'bidirectional' - scope: string of VariableScope, the scope of equivalent subgraph - consisting only platform-independent tf RNN cells. - name: the name of the CudnnOpaqueParamsSaveable object. + 'unidirectional' or 'bidirectional' """ - # Define in subclasses. self._num_layers = num_layers self._input_size = input_size self._num_units = num_units self._input_mode = input_mode self._direction = direction - if scope is not None: - scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope - self._scope = scope_name or None - else: - self._scope = None - - self._variables = opaque_params self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 self._num_params = ( self._num_params_per_layer * self._num_layers * self._num_dirs) - weights, biases = self._OpaqueParamsToCanonical() - (weights, weight_names), (biases, bias_names) = self._TransformCanonical( - weights, biases) - # We currently don't use slice_spec. It might be useful in a distributed - # setting where each parameter server node stores a slice of variable, - # instead of having the master pull all slices and then save them. 
- slice_spec = "" - params = weights + biases - self._weight_names = weight_names - self._bias_names = bias_names - self._param_names = weight_names + bias_names - prefixed_param_names = weight_names + bias_names - if self._scope: - prefixed_param_names = [ - "%s/%s" % (self._scope, pn) for pn in prefixed_param_names] - specs = [ - saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) - for param, param_name in zip(params, prefixed_param_names) - ] - super(CudnnOpaqueParamsSaveable, self).__init__( - array_ops.identity(self._variables), specs, name) - - def restore(self, restored_tensors, restored_shapes): - weights, biases = self._ReverseTransformCanonical(restored_tensors) - weights = [array_ops.reshape(w, [-1]) for w in weights] - opaque_params = self._CanonicalToOpaqueParams(weights, biases) - - return state_ops.assign( - self._variables, opaque_params, validate_shape=False) + def tf_canonical_to_opaque(self, tf_canonicals): + r"""Converts tf canonical weights to cudnn opaque param.""" + cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals) + cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights] + opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases) + return opaque_params - def _checkpointable_save(self, save_buffer): - weights, biases = self._OpaqueParamsToCanonical() - with ops.device("gpu:0"): - (weights, _), (biases, _) = self._TransformCanonical( - weights, biases) - for name, tensor in zip(self._param_names, weights + biases): - save_buffer[name] = array_ops.identity(tensor) + def opaque_to_tf_canonical(self, opaque_param): + r"""Converts cudnn opaque param to tf canonical weights.""" + cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param) + weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases) + return weights, biases - def _checkpointable_restore(self, restore_buffer): - tensors = [array_ops.identity(restore_buffer[name]) - for name in self._param_names] - return 
self.restore( - restored_tensors=tensors, - restored_shapes=None # Unused - ) - - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. - - When saving or restoring, converts to or from the opaque buffer - format. Weights are saved and loaded in the configuration expected by - cuDNN-compatible cells. - - Args: - checkpointable: An object inheriting from `CheckpointableBase` to add - dependencies too (typically the cuDNN `Layer`). - dtype: The dtype for the canonical parameter Tensors. - """ - split_dependencies = split_dependency.split_dependency( - component_names=self._param_names, - component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) - - def _checkpointable_track_params(self, checkpointable, params): - """Tracks parameters in a canonical configuration.""" - return # NotImplementedError raised by the Layer. - - def _TFCanonicalNamePrefix(self, layer, is_fwd=True): - if self._direction == CUDNN_RNN_UNIDIRECTION: - return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) - else: - if is_fwd: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % - (layer, self._rnn_cell_name)) - else: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % - (layer, self._rnn_cell_name)) - - def _OpaqueParamsToCanonical(self): + def _opaque_to_cu_canonical(self, opaque_param): """Converts opaque params to Cudnn canonical format. + Args: + opaque_param: An opaque tensor storing cudnn rnn params (weights and + biases). Returns: 2 list for weights and biases respectively. 
""" @@ -351,14 +239,14 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, - params=self._variables, + params=opaque_param, num_params=self._num_params, rnn_mode=self._rnn_mode, input_mode=self._input_mode, direction=self._direction) return (weights, biases) - def _CanonicalToOpaqueParams(self, cu_weights, cu_biases): + def _cu_canonical_to_opaque(self, cu_weights, cu_biases): """Converts from Cudnn canonical format to opaque params. Args: @@ -378,7 +266,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): input_mode=self._input_mode, direction=self._direction) - def _TransformCanonical(self, cu_weights, cu_biases): + def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. The elements of argument lists are laid out in the following format: @@ -398,46 +286,43 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): cu_weights: a list of tensors of Cudnn canonical weights. cu_biases: a list of tensors of Cudnn canonical biases. Returns: - 2 tuples, one for weights and the other for bias. - Each tuple has two lists: the 1st for transformed tf canonical tensors - and the 2nd for the names of the tensors under which they are saved. + 1 tuple, tf canonical weights and biases. 
""" tf_weights, tf_biases = [], [] - tf_weights_names, tf_bias_names = [], [] layer_weights_num = self._num_params_per_layer * self._num_dirs layer_biases_num = layer_weights_num for i in range(self._num_layers): - layer_weights = cu_weights[i * layer_weights_num: - (i + 1) * layer_weights_num] + layer_weights = cu_weights[i * layer_weights_num:(i + 1) * + layer_weights_num] layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: - prefix = self._TFCanonicalNamePrefix(i) - self._TransformSingleLayerCanonical(layer_weights, layer_biases, prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) + self._cu_canonical_to_tf_canonical_single_layer( + layer_weights, layer_biases, tf_weights, tf_biases) else: - fw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=True) - bw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=False) - fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] fw_biases = layer_biases[:len(layer_biases) // 2] bw_biases = layer_biases[len(layer_biases) // 2:] - self._TransformSingleLayerCanonical(fw_weights, fw_biases, fw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - - self._TransformSingleLayerCanonical(bw_weights, bw_biases, bw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - return (tf_weights, tf_weights_names), (tf_biases, tf_bias_names) - - def _TransformSingleLayerCanonical(self, cu_weights, cu_biases, prefix, - tf_weights, tf_weights_names, tf_biases, - tf_bias_names): + self._cu_canonical_to_tf_canonical_single_layer( + fw_weights, + fw_biases, + tf_weights, + tf_biases, + ) + + self._cu_canonical_to_tf_canonical_single_layer( + bw_weights, + bw_biases, + tf_weights, + tf_biases, + ) + return (tf_weights, tf_biases) + + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): r"""Transform single layer Cudnn canonicals to tf 
canonicals. The elements of cu_weights, cu_biases are laid out in the following format: @@ -447,15 +332,12 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Args: cu_weights: a list of tensors, single layer weights. cu_biases: a list of tensors, single layer biases. - prefix: the shared prefix of all tensor names. tf_weights: a list where transformed weights are stored. - tf_weights_names: a list where names of transformed weights are stored. tf_biases: a list where transformed biases are stored. - tf_bias_names: a list where names of transformed biases are stored. """ raise NotImplementedError("Abstract method") - def _ReverseTransformCanonical(self, tf_canonicals): + def _tf_canonical_to_cu_canonical(self, tf_canonicals): r"""Transform from tf canonical to Cudnn canonical. This is the reverse routine of _TransformCanonical(). @@ -502,30 +384,27 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return cu_weights, cu_biases def _cudnn_to_tf_weights(self, *cu_weights): - r"""Stitching cudnn canonical weights to generate tf canonical weights.""" + r"""Stitches cudnn canonical weights to generate tf canonical weights.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_weights(self, layer, *tf_weights): - r"""Reverse the operations in StitchWeights().""" + r"""Reverses the operations in StitchWeights().""" raise NotImplementedError("Abstract method") def _cudnn_to_tf_biases(self, *biases): - r"""Stitching cudnn canonical biases to generate tf canonical biases.""" + r"""Stitches cudnn canonical biases to generate tf canonical biases.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_biases(self, *tf_biases): - r"""Reverse the operations in StitchBiases().""" + r"""Reverses the operations in StitchBiases().""" raise NotImplementedError("Abstract method") -class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn LSTM opaque params.""" - +class 
CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF LSTM.""" _rnn_mode = CUDNN_LSTM _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) - def _cudnn_to_tf_gate_params(self, *cu_gate_order): i_g, f_g, c_g, o_g = cu_gate_order return [i_g, c_g, f_g, o_g] @@ -603,44 +482,16 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # Return ifco order for Cudnn LSTM. return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): - (w,) = self._cudnn_to_tf_weights(*weights) - (b,) = self._cudnn_to_tf_biases(*biases) - + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): + (w,) = self._cudnn_to_tf_weights(*cu_weights) + (b,) = self._cudnn_to_tf_biases(*cu_biases) tf_weights.append(w) - tf_weights_names.append(prefix + "/kernel") - tf_biases.append(b) - tf_bias_names.append(prefix + "/bias") - - def _checkpointable_track_params(self, checkpointable, params): - """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" - biases = [] - weights = [] - for name in self._weight_names: - weights.append(params[name]) - for name in self._bias_names: - biases.append(params[name]) - assert len(params) == len(weights) + len(biases) - if len(weights) == 1 and len(biases) == 1: - # For single-layer cells, allow substituting a cell with no MultiRNNCell - # wrapping. 
- kernel, = weights # pylint: disable=unbalanced-tuple-unpacking - bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access - assert len(biases) == len(weights) - for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access - cell.bias = bias - cell.kernel = kernel -class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn GRU opaque params.""" +class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF GRU.""" _rnn_mode = CUDNN_GRU _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER @@ -702,29 +553,18 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): b_ri, b_rr = array_ops.split(br, 2, axis=0) return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): # pylint: disable=invalid-name - W_ir, w_h, r_h = self._cudnn_to_tf_weights(*weights) - b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*biases) + W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights) + b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases) # pylint: enable=invalid-name - tf_weights.extend([W_ir, w_h, r_h]) - tf_weights_names.append(prefix + "/gates/kernel") - tf_weights_names.append(prefix + "/candidate/input_projection/kernel") - tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") - tf_biases.extend([b_ir, b_wh, b_rh]) - tf_bias_names.append(prefix + "/gates/bias") - tf_bias_names.append(prefix + 
"/candidate/input_projection/bias") - tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") - -class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" - _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) +class CudnnParamsFormatConverterBasic(CudnnParamsFormatConverterLSTM): + """Helper class that converts between params of Cudnn and TF Relu/Tanh RNN.""" def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -766,18 +606,270 @@ class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): return b_i, b_h -class CudnnRNNTanhSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" +class CudnnParamsFormatConverterTanh(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Tanh RNN.""" _rnn_mode = CUDNN_RNN_TANH _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER -class CudnnRNNReluSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Relu opaque params.""" +class CudnnParamsFormatConverterRelu(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Relu RNN.""" _rnn_mode = CUDNN_RNN_RELU _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER +# TODO(yaozhang): make sure we only save the canonical version of params and +# don't save the platform-specific version to avoid potential race +# conditions where params is updated by both versions when being restored. +# Currently, checkpointing will function properly, despite that we save both +# versions, because Saver restores customized savables after Variables. +# However, it is good to not rely on this restoring order of Saver and to +# avoid unnecessary storage. Add a test to check only the canonical version is +# saved. 
+class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): + """Abstract SaveableObject implementation handling Cudnn opaque params.""" + + def __init__(self, + opaque_params, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + scope=None, + name="cudnn_rnn_saveable"): + """Creates a CudnnOpaqueParamsSaveable object. + + CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file + and is used to save/restore the weights and biases parameters in a + canonical format which is directly consumable by platform-independent tf + RNN cells. Parameters are saved as tensors layer by layer with weight + tensors followed by bias tensors, and forward direction followed by + backward direction (if applicable). When restoring, a user could name + param_variables as desired, and restore weight and bias tensors to these + variables. + + For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per + bias for each layer: tensor 0 is applied to the input from the previous + layer and tensor 1 to the recurrent input. + + For CudnnLSTM, there are 8 tensors per weight and per bias for each + layer: tensor 0-3 are applied to the input from the previous layer and + tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; + tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; + tensor 3 and 7 the output gate. + + For CudnnGRU, there are 6 tensors per weight and per bias for each layer: + tensor 0-2 are applied to the input from the previous layer and + tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; + tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + + Args: + opaque_params: a variable, Cudnn RNN opaque params. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. 
+ input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + scope: string of VariableScope, the scope of equivalent subgraph + consisting only platform-independent tf RNN cells. + name: the name of the CudnnOpaqueParamsSaveable object. + """ + # Define in subclasses. + self._num_layers = num_layers + self._input_size = input_size + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + if scope is not None: + scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope + self._scope = scope_name or None + else: + self._scope = None + + self._variables = opaque_params + self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + # Defined in subclasses. + self._format_converter = None + + tf_weights, tf_biases = ( + self.format_converter.opaque_to_tf_canonical(self._variables)) + tf_weight_names, tf_bias_names = self._tf_canonical_names() + # We currently don't use slice_spec. It might be useful in a distributed + # setting where each parameter server node stores a slice of variable, + # instead of having the master pull all slices and then save them. 
+ slice_spec = "" + params = tf_weights + tf_biases + self._weight_names = tf_weight_names + self._bias_names = tf_bias_names + self._param_names = tf_weight_names + tf_bias_names + prefixed_param_names = tf_weight_names + tf_bias_names + if self._scope: + prefixed_param_names = [ + "%s/%s" % (self._scope, pn) for pn in prefixed_param_names + ] + specs = [ + saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) + for param, param_name in zip(params, prefixed_param_names) + ] + super(CudnnOpaqueParamsSaveable, self).__init__( + array_ops.identity(self._variables), specs, name) + + @property + def format_converter(self): + if self._format_converter is None: + self._format_converter = self._format_converter_cls( + self._num_layers, self._num_units, self._input_size, self._input_mode, + self._direction) + return self._format_converter + + def restore(self, restored_tensors, restored_shapes): + opaque_params = self.format_converter.tf_canonical_to_opaque( + restored_tensors) + return state_ops.assign( + self._variables, opaque_params, validate_shape=False) + + def _checkpointable_save(self, save_buffer): + weights, biases = self.format_converter.opaque_params_to_tf_canonical( + self._variables) + for name, tensor in zip(self._param_names, weights + biases): + save_buffer[name] = array_ops.identity(tensor) + + def _checkpointable_restore(self, restore_buffer): + tensors = [ + array_ops.identity(restore_buffer[name]) for name in self._param_names + ] + return self.restore( + restored_tensors=tensors, + restored_shapes=None # Unused + ) + + def _add_checkpointable_dependencies(self, checkpointable, dtype): + """Add canonical weight dependencies to `checkpointable`. + + When saving or restoring, converts to or from the opaque buffer + format. Weights are saved and loaded in the configuration expected by + cuDNN-compatible cells. + + Args: + checkpointable: An object inheriting from `CheckpointableBase` to add + dependencies too (typically the cuDNN `Layer`). 
+ dtype: The dtype for the canonical parameter Tensors. + """ + split_dependencies = split_dependency.split_dependency( + component_names=self._param_names, + component_dtypes=(dtype,) * len(self._param_names), + fill_save_buffer_fn=self._checkpointable_save, + consume_restore_buffer_fn=self._checkpointable_restore) + self._checkpointable_track_params(checkpointable, split_dependencies) + + def _checkpointable_track_params(self, checkpointable, params): + """Tracks parameters in a canonical configuration.""" + return # NotImplementedError raised by the Layer. + + def _tf_canonical_names(self): + tf_weights_names, tf_biases_names = [], [] + for i in range(self._num_layers): + if self._direction == CUDNN_RNN_UNIDIRECTION: + prefix = self._tf_canonical_name_prefix(i) + self._tf_canonical_names_single_layer(prefix, tf_weights_names, + tf_biases_names) + else: + fwd_prefix = self._tf_canonical_name_prefix(i, is_fwd=True) + bak_prefix = self._tf_canonical_name_prefix(i, is_fwd=False) + + self._tf_canonical_names_single_layer(fwd_prefix, tf_weights_names, + tf_biases_names) + self._tf_canonical_names_single_layer(bak_prefix, tf_weights_names, + tf_biases_names) + return tf_weights_names, tf_biases_names + + def _tf_canonical_name_prefix(self, layer, is_fwd=True): + if self._direction == CUDNN_RNN_UNIDIRECTION: + return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) + else: + if is_fwd: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % + (layer, self._rnn_cell_name)) + else: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % + (layer, self._rnn_cell_name)) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_biases_names): + raise NotImplementedError("Abstract method") + + +class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn LSTM opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterLSTM + _rnn_cell_name = 
base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/kernel") + tf_bias_names.append(prefix + "/bias") + + def _checkpointable_track_params(self, checkpointable, params): + """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" + biases = [] + weights = [] + for name in self._weight_names: + weights.append(params[name]) + for name in self._bias_names: + biases.append(params[name]) + assert len(params) == len(weights) + len(biases) + if len(weights) == 1 and len(biases) == 1: + # For single-layer cells, allow substituting a cell with no MultiRNNCell + # wrapping. + kernel, = weights # pylint: disable=unbalanced-tuple-unpacking + bias, = biases # pylint: disable=unbalanced-tuple-unpacking + checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access + checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + assert len(biases) == len(weights) + for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): + cell = checkpointable_lib.Checkpointable() + checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell.bias = bias + cell.kernel = kernel + + +class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn GRU opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterGRU + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/gates/kernel") + tf_weights_names.append(prefix + "/candidate/input_projection/kernel") + tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") + + tf_bias_names.append(prefix + "/gates/bias") + tf_bias_names.append(prefix + 
"/candidate/input_projection/bias") + tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") + + +class CudnnRNNTanhSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterTanh + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + +class CudnnRNNReluSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterRelu + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + _cudnn_rnn_common_doc_string = """ Cudnn RNN has an opaque parameter buffer that can be used for inference and training. But it is possible that the layout of the parameter buffers @@ -850,7 +942,7 @@ def _get_num_params(rnn_mode, num_layers, direction): elif rnn_mode == CUDNN_RNN_TANH: num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: - raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) + raise ValueError("Invalid \'rnn_mode\': %s" % rnn_mode) num_params = num_layers * num_params_per_layer if direction != CUDNN_RNN_UNIDIRECTION: num_params *= 2 @@ -918,7 +1010,7 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 is not "1": + if use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1582,7 +1674,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): """ if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s", direction) + raise ValueError("Invalid direction: %s" % direction) super(_CudnnRNNNoInputC, self).__init__( self._rnn_mode, diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index f82453f3b5ea01b8bb64a70bd49f5e3e831bb4e2..a938f8629d8210b4b512338a040340f21d3ef594 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -46,6 +46,9 @@ Let's see how to scale to multiple GPUs on one machine using 
`MirroredStrategy` Take a very simple model consisting of a single layer: ```python +import tensorflow as tf +from tensorflow import keras + inputs = tf.keras.layers.Input(shape=(1,)) predictions = tf.keras.layers.Dense(1)(inputs) model = tf.keras.models.Model(inputs=inputs, outputs=predictions) @@ -90,8 +93,8 @@ Similarly, we can also call `evaluate` and `predict` as before using appropriate datasets. ```python -model.evaluate(eval_dataset) -model.predict(predict_dataset) +model.evaluate(eval_dataset, steps=1) +model.predict(predict_dataset, steps=1) ``` That's all you need to train your model with Keras on multiple GPUs with diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 22736c799d276033c0ddc112d17e898be944c933..4094e52169aab0b46da4f62087ddac4f750039a4 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -374,9 +374,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", - # TODO(b/118820960): Re-enable this test in guitar. - "manual", - "noguitar", ], ) @@ -470,6 +467,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", ], ) @@ -492,6 +490,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", ], ) @@ -757,8 +756,6 @@ cuda_py_test( "no_oss", # TODO(b/117919883): Fix python error. "no_pip", "no_windows_gpu", - # TODO(b/118815591): Re-enable this test in guitar.) 
- "noguitar", "notsan", ], ) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index b311644cb22898df082b0c803d1a8960fe159c98..d38bdb592a303d23871b48d80868917efc01dcd1 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -69,7 +69,7 @@ class CheckpointUtilsWithDistributionStrategyTest( with ops.Graph().as_default() as g, distribution.scope(): if in_replica_mode: - distribution.call_for_each_replica(init_and_verify, g) + distribution.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index d9339f8f75acda3695d33c55409e921a9627bac7..efa99d1fc52e8facfaeb61f98b5e649a18f6a3cf 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -205,7 +205,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): def distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. 
- return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices, True) def configure(self, diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 19b59513d81c5b0cebf5e44aa66b110db86a91c8..e3d919dd0d482f49d9a934c879e9adad25c03f86 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -54,8 +54,6 @@ class CollectiveAllReduceStrategyTestBase( self._run_options = config_pb2.RunOptions() self._run_options.experimental.collective_graph_key = 6 - self._sess_config = config_pb2.ConfigProto() - # We use a different key_base for each test so that collective keys won't be # reused. # TODO(yuefengz, tucker): enable it to reuse collective keys in different @@ -66,9 +64,10 @@ class CollectiveAllReduceStrategyTestBase( def _get_test_object(self, task_type, task_id, num_gpus=0): distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( num_gpus_per_worker=num_gpus) + session_config = config_pb2.ConfigProto() if task_type and task_id is not None: distribution.configure( - session_config=self._sess_config, + session_config=session_config, cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) @@ -82,14 +81,16 @@ class CollectiveAllReduceStrategyTestBase( distribution._collective_keys = collective_keys distribution._cross_tower_ops._collective_keys = collective_keys if task_type and task_id is not None: - return distribution, 'grpc://' + self._cluster_spec[task_type][task_id] + return distribution, 'grpc://' + self._cluster_spec[task_type][ + task_id], session_config else: - return distribution, '' + return distribution, '', session_config def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, 
task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess, \ d.scope(): l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) @@ -117,7 +118,7 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -154,7 +155,8 @@ class CollectiveAllReduceStrategyTestBase( return error_after < error_before def _test_complex_model(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) def model_fn(): """Mnist model with synthetic input.""" @@ -193,7 +195,7 @@ class CollectiveAllReduceStrategyTestBase( return train_op with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess: with d.scope(): train_op = d.call_for_each_replica(model_fn) @@ -204,10 +206,10 @@ class CollectiveAllReduceStrategyTestBase( return True def _test_variable_initialization(self, task_type, task_id, num_gpus): - distribution, master_target = self._get_test_object(task_type, task_id, - num_gpus) + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self.cached_session(config=self._sess_config, + self.cached_session(config=config, target=master_target) as sess, \ distribution.scope(): diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 
63a163e76cdd99c73399c657cbe9bc3d010369d2..a51371654031e32d084e2b0e8ae345bb2c166ae8 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -335,17 +335,13 @@ tpu_strategy_one_step = NamedDistribution( "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) -# Note that we disable prefetching for testing since prefetching makes -# the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index bae0f474d27b3256358f8ac08cdd6b5f04be56c5..b5b349aa64e4df46bd11ab8f01ec488afd3a26a7 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -62,26 +62,26 @@ def validate_destinations(destinations): raise ValueError("destinations can not be empty") -def _make_tensor_into_per_device(input_tensor): - """Converts a single tensor into a PerDevice object.""" +def _make_tensor_into_per_replica(input_tensor): + """Converts a single tensor into a PerReplica object.""" if isinstance(input_tensor, (tuple, list)): - raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, " + raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, " "got %r but expected a object that is not a tuple or list." 
% (input_tensor,)) - if isinstance(input_tensor, value_lib.PerDevice): + if isinstance(input_tensor, value_lib.PerReplica): return input_tensor try: device = input_tensor.device except AttributeError: - raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object " + raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object " "because it doesn't have device set.") - return value_lib.PerDevice({device: input_tensor}) + return value_lib.PerReplica({device: input_tensor}) def _normalize_value_destination_pairs(value_destination_pairs): - """Converts each tensor into a PerDevice object in the input list.""" + """Converts each tensor into a PerReplica object in the input list.""" result = [] if not isinstance(value_destination_pairs, (list, tuple)): raise ValueError("`value_destination_pairs` should be a list or tuple") @@ -93,8 +93,8 @@ def _normalize_value_destination_pairs(value_destination_pairs): raise ValueError("Each element of `value_destination_pairs` should be a " "tuple of size 2.") - per_device = _make_tensor_into_per_device(pair[0]) - result.append((per_device, pair[1])) + per_replica = _make_tensor_into_per_replica(pair[0]) + result.append((per_replica, pair[1])) return result @@ -105,7 +105,7 @@ def _validate_value_destination_pairs(value_destination_pairs): if not isinstance(value_destination_pairs, (list, tuple)): return False if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): return False - if not all([isinstance(v[0], value_lib.PerDevice) + if not all([isinstance(v[0], value_lib.PerReplica) for v in value_destination_pairs]): return False return True @@ -149,26 +149,16 @@ def _simple_broadcast(value, destinations): return value_lib.Mirrored(index) -def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, +def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, aggregation): # pylint: disable=g-missing-docstring all_values = [] count = 0 - for v in 
per_device_value._index.values(): # pylint: disable=protected-access - if isinstance(v, value_lib.MapOutput): - v_list = v.get() - if not v_list: - continue - count += len(v_list) - # Sum within each device before aggregating across devices. - # TODO(yuefengz): Check whether it helps to use accumulation_fn here. - v = cross_tower_utils.aggregate_tensors_or_indexed_slices( - v_list, math_ops.add_n) - else: - count += 1 + for v in per_replica_value._index.values(): # pylint: disable=protected-access + count += 1 all_values.append(v) if not all_values: - raise ValueError("`per_device_value` must be non-empty") + raise ValueError("`per_replica_value` must be non-empty") with ops.device(reduce_to_device): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): @@ -189,8 +179,8 @@ class CrossDeviceOps(object): def __init__(self): pass - def reduce(self, aggregation, per_device_value, destinations): - """Reduce `per_device_value` to `destinations`. + def reduce(self, aggregation, per_replica_value, destinations): + """Reduce `per_replica_value` to `destinations`. It runs the reduction operation defined by `aggregation` and put the result on `destinations`. @@ -198,23 +188,23 @@ class CrossDeviceOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - per_device_value: a PerDevice object or a tensor with device set. + per_replica_value: a PerReplica object or a tensor with device set. destinations: the reduction destinations. Returns: a Mirrored object. Raises: - ValueError: if per_device_value is not a PerDevice object. + ValueError: if per_replica_value is not a PerReplica object. 
""" - if not isinstance(per_device_value, value_lib.PerDevice): - per_device_value = _make_tensor_into_per_device(per_device_value) + if not isinstance(per_replica_value, value_lib.PerReplica): + per_replica_value = _make_tensor_into_per_replica(per_replica_value) validate_destinations(destinations) - return self._reduce(aggregation, per_device_value, destinations) + return self._reduce(aggregation, per_replica_value, destinations) def batch_reduce(self, aggregation, value_destination_pairs): - """Reduce PerDevice objects in a batch. + """Reduce PerReplica objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. @@ -222,7 +212,7 @@ class CrossDeviceOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - value_destination_pairs: a list or a tuple of tuples of PerDevice objects + value_destination_pairs: a list or a tuple of tuples of PerReplica objects (or tensors with device set if there is one device) and destinations. Returns: @@ -230,11 +220,11 @@ class CrossDeviceOps(object): Raises: ValueError: if `value_destination_pairs` is not a list or a tuple of - tuples of PerDevice objects and destinations + tuples of PerReplica objects and destinations """ if not _validate_value_destination_pairs(value_destination_pairs): # If the first element of each pair is a tensor, we try to turn it into a - # PerDevice object. + # PerReplica object. 
value_destination_pairs = _normalize_value_destination_pairs( value_destination_pairs) @@ -256,7 +246,7 @@ class CrossDeviceOps(object): validate_destinations(destinations) return self._broadcast(tensor, destinations) - def _reduce(self, aggregation, per_device_value, destinations): + def _reduce(self, aggregation, per_replica_value, destinations): raise NotImplementedError( "_reduce method must be implemented in descendants.") @@ -286,13 +276,13 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps): self.accumulation_fn = accumulation_fn super(ReductionToOneDeviceCrossDeviceOps, self).__init__() - def _reduce(self, aggregation, per_device_value, destinations): + def _reduce(self, aggregation, per_replica_value, destinations): if check_destinations(destinations): devices = get_devices_from(destinations) else: - devices = get_devices_from(per_device_value) + devices = get_devices_from(per_replica_value) reduce_to_device = self.reduce_to_device or devices[0] - reduced = _simple_reduce(per_device_value, reduce_to_device, + reduced = _simple_reduce(per_replica_value, reduce_to_device, self.accumulation_fn, aggregation) return self.broadcast(reduced, devices) @@ -303,7 +293,7 @@ class ReductionToOneDeviceCrossDeviceOps(CrossDeviceOps): ] -def _group_value_by_device(per_device_values): +def _group_value_by_device(per_replica_values): """Group values into sublists by their devices. This grouping is needed to call the all-reduce library because it expects a @@ -315,18 +305,18 @@ def _group_value_by_device(per_device_values): ] Args: - per_device_values: a list of PerDevice obejcts. + per_replica_values: a list of PerReplica obejcts. Returns: a list of lists, each sublist has components for its corresponding device of - PerDevice objects, paired with a None. + PerReplica objects, paired with a None. 
""" - destinations = per_device_values[0].devices + destinations = per_replica_values[0].devices grouped = [[] for _ in range(len(destinations))] - for per_device_value in per_device_values: + for per_replica_value in per_replica_values: # pylint: disable=protected-access - for i, v in enumerate(per_device_value._index.values()): - assert per_device_value.devices == destinations + for i, v in enumerate(per_replica_value._index.values()): + assert per_replica_value.devices == destinations grouped[i].append((v, None)) return grouped @@ -354,8 +344,8 @@ def _ungroup_and_make_mirrored(grouped_reduced, a list of Mirrored objects. """ index = [{} for _ in range(len(grouped_reduced[0]))] - for d, per_device_reduced in enumerate(grouped_reduced): - for i, (v, _) in enumerate(per_device_reduced): + for d, per_replica_reduced in enumerate(grouped_reduced): + for i, (v, _) in enumerate(per_replica_reduced): if aggregation == vs.VariableAggregation.MEAN: index[i][destinations[d]] = v / ( len(destinations) * num_between_graph_workers) @@ -567,13 +557,13 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossDeviceOps, self).__init__() - def _reduce(self, aggregation, per_device_value, destinations): + def _reduce(self, aggregation, per_replica_value, destinations): contains_indexed_slices = cross_tower_utils.contains_indexed_slices( - per_device_value) - if (_devices_match(per_device_value, destinations) + per_replica_value) + if (_devices_match(per_replica_value, destinations) and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(aggregation, [per_device_value])[0] + return self._batch_all_reduce(aggregation, [per_replica_value])[0] else: if contains_indexed_slices: logging.log_first_n( @@ -583,9 +573,9 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): if check_destinations(destinations): devices = get_devices_from(destinations) else: - devices = 
get_devices_from(per_device_value) + devices = get_devices_from(per_replica_value) reduce_to_device = devices[0] - reduced = _simple_reduce(per_device_value, reduce_to_device, + reduced = _simple_reduce(per_replica_value, reduce_to_device, math_ops.add_n, aggregation) return self.broadcast(reduced, devices) @@ -609,16 +599,16 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): for t, v in value_destination_pairs ] - def _batch_all_reduce(self, aggregation, per_device_values): + def _batch_all_reduce(self, aggregation, per_replica_values): """All reduce algorithm in a batch.""" logging.log_first_n( logging.INFO, "batch_all_reduce invoked for batches size = %d with " "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " "agg_small_grads_max_group = %d" % - (len(per_device_values), self._all_reduce_alg, self._num_packs, + (len(per_replica_values), self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) - destinations = per_device_values[0].devices - grouped = _group_value_by_device(per_device_values) + destinations = per_replica_values[0].devices + grouped = _group_value_by_device(per_replica_values) device_grad_packs, tensor_packer = _pack_tensors( grouped, self._num_packs, self._agg_small_grads_max_bytes, @@ -639,7 +629,7 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): destinations, device_grad_packs)) reduced = _unpack_tensors(reduced, tensor_packer) - return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, + return _ungroup_and_make_mirrored(reduced, per_replica_values[0].devices, aggregation) @@ -723,18 +713,18 @@ class MultiWorkerAllReduce(AllReduceCrossDeviceOps): validate_and_complete_spec(spec) for spec in all_reduce_spec ] - def _batch_all_reduce(self, aggregation, per_device_values): + def _batch_all_reduce(self, aggregation, per_replica_values): """All reduce algorithm in a batch.""" logging.log_first_n( logging.INFO, "distributed batch_all_reduce invoked for batches 
size = %d with " "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " "and agg_small_grads_max_group = %d" % - (len(per_device_values), self._all_reduce_spec, self._num_packs, + (len(per_replica_values), self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) - destinations = sorted(per_device_values[0].devices) - device_grads = _group_value_by_device(per_device_values) + destinations = sorted(per_replica_values[0].devices) + device_grads = _group_value_by_device(per_replica_values) # The all reduce library requires fully defined shapes. # TODO(yuefengz): when tensor sharding is not needed, static shapes are not @@ -805,16 +795,16 @@ class CollectiveAllReduce(CrossDeviceOps): super(CollectiveAllReduce, self).__init__() # TODO(yuefengz, tucker): is indexed slices supported by collective ops? - def _reduce(self, aggregation, per_device_value, destinations): - if cross_tower_utils.contains_indexed_slices(per_device_value): + def _reduce(self, aggregation, per_replica_value, destinations): + if cross_tower_utils.contains_indexed_slices(per_replica_value): raise ValueError( "`IndexSlices` is not supported for Collective All-Reduce.") if context.executing_eagerly(): raise ValueError( "Eager execution is not supported for Collective All-Reduce") - all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] - if _devices_match(per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_replica_value])[0] + if _devices_match(per_replica_value, destinations): return all_reduced else: index = {} @@ -852,7 +842,7 @@ class CollectiveAllReduce(CrossDeviceOps): for t, v in value_destination_pairs ] - def _batch_all_reduce(self, aggregation, per_device_values): + def _batch_all_reduce(self, aggregation, per_replica_values): """All-reduce across all workers in a batch.""" if context.executing_eagerly(): raise ValueError( @@ -860,9 +850,9 @@ class 
CollectiveAllReduce(CrossDeviceOps): logging.log_first_n( logging.INFO, "Collective All-reduce invoked with batches size = %d, " - "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + "num_workers = %d" % (len(per_replica_values), self._num_workers), 10) - grouped_by_device = _group_value_by_device(per_device_values) + grouped_by_device = _group_value_by_device(per_replica_values) grouped_by_var = list(zip(*grouped_by_device)) # grouped_by_var is grouped by variables and takes the following format: @@ -892,7 +882,7 @@ class CollectiveAllReduce(CrossDeviceOps): new_device_grads = [list(x) for x in zip(*reduced_gv_list)] return _ungroup_and_make_mirrored( new_device_grads, - per_device_values[0].devices, + per_replica_values[0].devices, aggregation, num_between_graph_workers=self._num_workers) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 6a9e8e00c02411d6486f30146f7f7d86ecd2fa9c..3e274ba67ca6709a14f5391968f28b721e46b8a6 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -40,12 +40,12 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util -def _make_per_device(values, devices, regroup=False): +def _make_per_replica(values, devices, regroup=False): devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) - # We simulate the result of regroup called on PerDevice which strips the - # PerDevice wrapper if it has only one value. + # We simulate the result of regroup called on PerReplica which strips the + # PerReplica wrapper if it has only one value. 
if len(values) == 1 and regroup: with ops.device(devices[0]): placed_v = array_ops.identity(values[0]) @@ -56,7 +56,7 @@ def _make_per_device(values, devices, regroup=False): with ops.device(d): placed_v = array_ops.identity(v) index[d] = placed_v - return value_lib.PerDevice(index) + return value_lib.PerReplica(index) # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -122,11 +122,11 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): devices = distribution.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] - per_device = _make_per_device(values, devices) + per_replica = _make_per_replica(values, devices) mean = (len(devices) - 1.) / 2. values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) + per_replica_2 = _make_per_replica(values_2, devices) mean_2 = mean + 1. destination_mirrored = _fake_mirrored(1., devices) @@ -144,39 +144,41 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): self._assert_values_equal( cross_tower_ops.reduce( vs.VariableAggregation.MEAN, - per_device, + per_replica, destinations=destinations), _fake_mirrored(mean, destinations)) self._assert_values_equal( cross_tower_ops.reduce( vs.VariableAggregation.MEAN, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations)) self._assert_values_equal( cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_device, + vs.VariableAggregation.SUM, per_replica, destinations=destinations), _fake_mirrored(mean * len(devices), destinations)) self._assert_values_equal( cross_tower_ops.reduce( vs.VariableAggregation.SUM, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, 
d1), (per_device_2, d2)]), + cross_tower_ops.batch_reduce( + vs.VariableAggregation.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ]) self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), (per_device_2, d2)]), + cross_tower_ops.batch_reduce( + vs.VariableAggregation.SUM, + [(per_replica, d1), (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices), d1), _fake_mirrored(mean_2 * len(devices), d2) @@ -277,9 +279,9 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): devices = ["/cpu:0", "/gpu:0"] t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) result = cross_tower_ops_lib._simple_reduce( - per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) + per_replica, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. 
@@ -311,13 +313,14 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) t1 = _make_indexed_slices( [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce(aggregation, - [(per_device, devices)]) + result = cross_tower_ops_instance.batch_reduce( + aggregation, [(per_replica, devices)]) else: - result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) + result = cross_tower_ops_instance.reduce( + aggregation, per_replica, devices) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] @@ -478,11 +481,11 @@ class MultiWorkerCollectiveAllReduceTest( # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_device = _make_per_device(values, devices, regroup=True) + per_replica = _make_per_replica(values, devices, regroup=True) mean = np.array([(len(devices) - 1.) 
/ 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) + per_replica_2 = _make_per_replica(values_2, devices) mean_2 = np.array([mean[0] + 1.]) destination_mirrored = _fake_mirrored(1., devices) @@ -500,26 +503,26 @@ class MultiWorkerCollectiveAllReduceTest( self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.MEAN, - per_device, + per_replica, destinations=destinations), _fake_mirrored(mean, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.MEAN, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.SUM, - per_device, + per_replica, destinations=destinations), _fake_mirrored(mean * len(devices) * num_workers, destinations), sess) self._assert_values_equal( collective_all_reduce.reduce( vs.VariableAggregation.SUM, - per_device_2, + per_replica_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), sess) @@ -528,16 +531,16 @@ class MultiWorkerCollectiveAllReduceTest( for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), - (per_device_2, d2)]), + [(per_replica, d1), + (per_replica_2, d2)]), [ _fake_mirrored(mean, d1), _fake_mirrored(mean_2, d2) ], sess) self._assert_values_equal( collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), - (per_device_2, d2)]), + [(per_replica, d1), + (per_replica_2, d2)]), [ _fake_mirrored(mean * len(devices) * num_workers, d1), _fake_mirrored(mean_2 * len(devices) * num_workers, d2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 
35324d15d4416364698390468d65d442f442ec50..50b3cf31e59d3bb9ab1471ffae174a04eb90ef8d 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -667,7 +667,5 @@ def contains_indexed_slices(value): return any(contains_indexed_slices(v) for v in value) elif isinstance(value, value_lib.DistributedValues): return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access - elif isinstance(value, value_lib.MapOutput): - return contains_indexed_slices(value.get()) else: return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py index d25964fa41adc7b1c9164a4ffe49c4c5532f76ac..e46240abbfa3d3618009f8bafe5db66e06e8bbd3 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -98,24 +98,13 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDevice(self): + def testContainsIndexedSlices_PerReplica(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) t1 = math_ops._as_indexed_slices( constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDeviceMapOutput(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({ - "/gpu:0": value_lib.MapOutput([t0]), - "/cpu:0": value_lib.MapOutput([t1])}) - 
self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_replica)) @combinations.generate(combinations.combine( mode=["graph", "eager"], diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 018512ae5a22eaa7fb78a8c4e5918fec22eb8178..8f82b4c92aa4305af121855972df4947c963850d 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -300,10 +300,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - try: - train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) - except TypeError: - train_distribute = train_distribute_cls(num_gpus_per_worker=2) + train_distribute = train_distribute_cls( + num_gpus_per_worker=context.num_gpus()) if eval_distribute_cls: eval_distribute = eval_distribute_cls( diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index c7036daa3e3321a29e4fc9ae30449fbf15b69b1b..0fd3acd045170c04ebdaa9c84d0cb7267a4bc68a 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -61,7 +61,6 @@ def get_input_datasets(use_bfloat16=False): # train dataset train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() - train_ds = train_ds.shuffle(100) train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) train_ds = train_ds.batch(64, drop_remainder=True) diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py index 
f07ec8234dfe87f2869cd7c2dd6a64c477712d15..ac1ccd64b3267645cbe10fdc02892fd4abd61df1 100644 --- a/tensorflow/contrib/distribute/python/input_ops.py +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import ops @@ -27,9 +28,8 @@ from tensorflow.python.platform import tf_logging # TODO(priyag): Any other reader datasets to consider here? _READER_DATASET_OPS = [ - "TextLineDataset", - "TFRecordDataset", - "FixedLengthRecordDataset" + "TextLineDataset", "TFRecordDataset", "FixedLengthRecordDataset", + "FixedLengthRecordDatasetV2" ] @@ -75,6 +75,8 @@ def auto_shard_dataset(dataset, num_shards, index): # instead of updating in-place. return dataset._clone( filenames=dataset._filenames.shard(num_shards, index)) + elif isinstance(dataset, dataset_ops.RangeDataset): + return dataset.shard(num_shards, index) elif hasattr(dataset, "_map_func"): # TODO(priyag): Make this check more robust by enforcing some common # property on all map/flatmap/interleave datasets. 
diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py index f4c222f26c3f6501cd78a69dd6a6d9a442a6bd24..46a1cf41c55b371e87979ca625765e0531ac188b 100644 --- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -157,7 +157,7 @@ class MirroredStrategyOptimizerV2Test(test.TestCase): dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): (var, m, v, op, counter) = dist.call_for_each_replica( - create_fn, dist.worker_device_index, run_concurrently=False) + create_fn, args=[dist.worker_device_index]) self.evaluate(variables.global_variables_initializer()) var_val = [2.0, 2.0, 2.0] self.assertAllClose( diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 4cd8ac14100ad3d35701670042faed7502846a31..0db5844e4c40e84c635b063523b95226241d07fb 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -47,7 +47,6 @@ _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 -_TOLERANCE = 1e-5 # TODO(anjalisridhar): Add a decorator that will allow us to run these tests as @@ -213,10 +212,76 @@ def multi_input_output_model(): return model +def get_correctness_test_inputs(use_numpy, with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ if with_distribution: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': 1, + 'shuffle': False, + } + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + # TODO(b/119318587): We should not require batch_size when distribution + # is enabled. + 'batch_size': (len(x_predict) // with_distribution.num_replicas_in_sync + if with_distribution else None), + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(train_dataset, batch_size, with_distribution) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': 1, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + predict_batch_size = len(x_predict) + if with_distribution: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'batch_size': None, + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs + + strategies = [combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy, # steps_per_run=2 combinations.tpu_strategy_one_step] @@ -245,6 +310,13 @@ def strategy_and_optimizer_combinations(): mode=['graph']) +def strategy_and_inputs(): + return combinations.combine( + distribution=strategies, + use_numpy=[True, False], + mode=['graph']) + + 
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): @@ -413,8 +485,8 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, with self.assertRaisesRegexp(ValueError, 'is smaller than the number ' 'of replicas'): - # The batch size(32) * num_replicas(3) is 96 which is greater than the - # number of input samples(64). + # The batch size(32) * num_replicas_in_sync(3) is 96 which is greater + # than the number of input samples(64). distributed_training_utils.get_input_batch_params(inputs, 32, strategy) @@ -598,36 +670,33 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(strategy_combinations()) def test_model_interleaved_eval_same_as_direct_eval(self, distribution): with self.cached_session(): - loss = 'mse' - user_controlled_model = get_model() - user_controlled_optimizer = gradient_descent.GradientDescentOptimizer( - 0.001) - user_controlled_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - user_controlled_model.compile(user_controlled_optimizer, loss, - metrics=user_controlled_metrics, - distribute=distribution) + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) interleaved_model = get_model() - interleaved_optimizer = gradient_descent.GradientDescentOptimizer(0.001) - interleaved_metrics = ['mae', keras.metrics.CategoricalAccuracy()] - interleaved_model.compile(interleaved_optimizer, loss, - metrics=interleaved_metrics, - distribute=distribution) + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) dataset = get_dataset(distribution) # Call fit with validation interleaved - interleaved_output = interleaved_model.fit(dataset, epochs=2, - steps_per_epoch=2, 
verbose=0, - validation_data=dataset, - validation_steps=2) + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) # Manually control the validation running after each epoch. user_controlled_output = [] for _ in range(2): user_controlled_model.fit( - dataset, epochs=1, steps_per_epoch=2, verbose=0) + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) user_controlled_output.append( user_controlled_model.evaluate(dataset, steps=2)) @@ -1019,26 +1088,36 @@ class TestDistributionStrategyCorrectness(test.TestCase, distribute=distribution) batch_size = 64 - batch_size //= distribution.num_replicas + batch_size //= distribution.num_replicas_in_sync train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) self.assertEqual(history.history['binary_accuracy'], [1.0]) - @combinations.generate(strategy_combinations()) - def test_correctness(self, distribution): + @combinations.generate(strategy_and_inputs()) + def test_correctness(self, distribution, use_numpy): with self.cached_session(): + tolerance = 1e-5 + + if isinstance(distribution, mirrored_strategy.MirroredStrategy): + # TODO(b/119257215): use the default one once the flakyness is fixed. + tolerance = 1e-4 + keras.backend.set_image_data_format('channels_last') - num_samples = 10000 np.random.seed(_RANDOM_SEED) random_seed.set_random_seed(_RANDOM_SEED) - # Train and predict datasets are created with the same input numpy arrays. + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. 
+ num_samples = 9984 x_train = np.random.rand(num_samples, 1) y_train = 3 * x_train x_train = x_train.astype('float32') y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] # The model is built once and the initial weights are saved. # This is used to initialize the model for both the distribution and @@ -1052,49 +1131,38 @@ class TestDistributionStrategyCorrectness(test.TestCase, initial_weights = model.get_weights() def fit_and_predict(with_distribution=None): + # We have initialized the model to the same weight for the distribution + # and non-distribution run. model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), distribute=with_distribution) - batch_size = 64 - if with_distribution: - batch_size //= with_distribution.num_replicas - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, - y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - # We have initialized the model to the same weight for the distribution - # and non-distribution run. If you want to initialize the model to - # random weights for each run, you need to run the model through the - # entire dataset at least once to ensure that the weights converge to - # the same value. 
- model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, with_distribution, + x_train, y_train, x_predict)) + model.fit(**training_inputs) + eval_result = model.evaluate(**eval_inputs) weights = model.get_weights() - x_predict = [[1.], [2.], [3.], [4.]] - predict_batch_size = 4 - if with_distribution: - predict_batch_size //= with_distribution.num_replicas - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, distribution) - predict_result = model.predict(predict_dataset, steps=1) - - return weights, predict_result - - wts_with_ds, predict_with_ds = fit_and_predict( + predict_result = model.predict(**predict_inputs) + + return weights, eval_result, predict_result + + wts_with_ds, eval_with_ds, predict_with_ds = fit_and_predict( with_distribution=distribution) - wts_without_ds, predict_without_ds = fit_and_predict( + wts_without_ds, eval_without_ds, predict_without_ds = fit_and_predict( with_distribution=None) - # Verify that the weights are the same within some limits of tolerance. + # Verify that the weights, eval results, predict outputs are the same + # within some limits of tolerance. + self.assertAllClose( + wts_with_ds, wts_without_ds, atol=tolerance, rtol=tolerance) self.assertAllClose( - wts_with_ds, wts_without_ds, atol=_TOLERANCE, rtol=_TOLERANCE) - # Verify that the predicted outputs are the same within some limits of - # tolerance. + eval_with_ds, eval_without_ds, atol=tolerance, rtol=tolerance) self.assertAllClose( - predict_with_ds, predict_without_ds, atol=_TOLERANCE, rtol=_TOLERANCE) + predict_with_ds, predict_without_ds, atol=tolerance, rtol=tolerance) # TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. 
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 9e1a7ad3932e3e8b79c70f1c07a241dcf52564f1..c28ab416518799e239bff43def75e00b7c22ee73 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -100,7 +100,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): value, update = distribution.call_for_each_replica( - metric_fn, inputs) + metric_fn, args=[inputs]) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) @@ -111,14 +111,14 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): # In each run, we run multiple steps, and each steps consumes as many # batches as number of replicas. batches_per_update = ( - distribution.num_replicas * distribution.steps_per_run) + distribution.num_replicas_in_sync * distribution.steps_per_run) else: value, update = distribution.call_for_each_replica( metric_fn, iterator.get_next()) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_replicas" with "1". - batches_per_update = distribution.num_replicas + # replace "distribution.num_replicas_in_sync" with "1". 
+ batches_per_update = distribution.num_replicas_in_sync self.evaluate(iterator.initializer) self.evaluate(distribution.initialize()) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 165732d578fd25b1ef631efc5827fd636427c7c8..c6562463edbf8e03d5771a5147dc227ddf438c40 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -22,7 +22,6 @@ from absl.testing import parameterized import numpy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops @@ -67,8 +66,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, *inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -111,7 +109,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def run_step(): return distribution.group( distribution.call_for_each_replica( - model_fn, iterator.get_next(), run_concurrently=layer.built)) + model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): with self.cached_session() as sess: @@ -162,8 +160,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, *inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) iterator = 
self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -221,7 +218,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm, update_ops_in_cross_replica_mode): """Verifies that moving mean updates are reduced across replicas.""" with distribution.scope(): - num_replicas = len(distribution.worker_devices) + num_replicas = distribution.num_replicas_in_sync model_fn, dataset_fn, batchnorm = batchnorm_example( optimizer_fn, batch_per_epoch=num_replicas, @@ -229,17 +226,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm=renorm, update_ops_in_replica_mode=not update_ops_in_cross_replica_mode) - # Make sure prefetching is disabled since that makes the - # specific input on each device to be non deterministic, and - # this test relies on specific input being on each device. - if isinstance(distribution, mirrored_strategy.MirroredStrategy): - self.assertFalse(distribution._prefetch_on_device) - def step_fn(ctx, *inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_replica( - model_fn, *inputs, run_concurrently=batchnorm.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) if update_ops_in_cross_replica_mode: fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) @@ -334,8 +324,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(ctx, x, y): del ctx # Unused return distribution.group( - distribution.call_for_each_replica( - model_fn, x, y, run_concurrently=False)) + distribution.call_for_each_replica(model_fn, args=(x, y))) iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) @@ -369,10 +358,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 # with sum loss reduction, or 10.6 with mean. 
if loss_reduction == losses_impl.Reduction.SUM: - # Note that the "distribution.num_replicas" factor will go away once - # we split the input across replicas, instead of pulling a complete + # Note that the "distribution.num_replicas_in_sync" factor will go away + # once we split the input across replicas, instead of pulling a complete # batch of input per replica. - self.assertNear(weight, 2 + 21.2 * distribution.num_replicas, 0.0001) + self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync, + 0.0001) else: # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) @@ -420,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def step_fn(output_context, *inputs): (train_op, loss) = distribution.call_for_each_replica( - model_fn, output_context, *inputs, run_concurrently=False) + model_fn, args=(output_context,) + inputs) output_context.set_last_step_output( name="cross_replica_loss_agg", output=loss, @@ -491,7 +481,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def _verify_loss_output(self, initial_loss, loss_output, aggregated, distribution): if not aggregated: - self.assertEqual(distribution.num_replicas, + self.assertEqual(distribution.num_replicas_in_sync, len(distribution.unwrap(loss_output))) loss_output = distribution.reduce( aggregation=variables_lib.VariableAggregation.MEAN, diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index c23de0694984076b1c9a8da45219436fc38cd286..2d75024e7a058af60df12cc0048a2c391d80073d 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -73,17 +73,14 @@ class _RequestedStop(Exception): # TODO(yuefengz): maybe create a common class for those who need to call this # _call_for_each_replica. 
-def _call_for_each_replica(distribution, fn, *args, **kwargs): +def _call_for_each_replica(distribution, fn, args, kwargs): """Run `fn` in separate threads, once per replica/worker device. Args: distribution: the DistributionStrategy object. fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. + args: positional arguments for `fn` + kwargs: keyword arguments for `fn`. Returns: Merged return value of `fn` across all replicas. @@ -92,16 +89,12 @@ def _call_for_each_replica(distribution, fn, *args, **kwargs): RuntimeError: If fn() calls get_replica_context().merge_call() a different number of times from the available devices. """ - run_concurrently = kwargs.pop("run_concurrently", True) + # TODO(josh11b): Add this option once we add synchronization to variable + # creation. Until then, this is pretty unsafe to use. + run_concurrently = False if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False # Needed for per-thread device, etc. contexts in graph mode. ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) @@ -192,7 +185,7 @@ def _reduce_non_distributed_value(distribution, aggregation, value, raise ValueError("You are passing a `DistributedValue` to " "`_reduce_non_distributed_value`, which is not allowed.") - # If the same value is present on all replicas then the PerDevice value will + # If the same value is present on all replicas then the PerReplica value will # be a single value. 
We also handle the case when `value` is a single value # and equal to 0. if value == 0: @@ -348,8 +341,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): specified. cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not set, the `configure` method will try to find the best one. - prefetch_on_device: optional boolean to specify whether to prefetch input - data to devices. auto_shard_dataset: whether to auto-shard the dataset when there are multiple workers. cross_tower_ops: Deprecated alias for `cross_device_ops`. @@ -360,14 +351,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_device_ops=None, - prefetch_on_device=None, auto_shard_dataset=False, cross_tower_ops=None): super(MirroredStrategy, self).__init__() assert not (cross_device_ops and cross_tower_ops) self._cross_tower_ops = cross_device_ops or cross_tower_ops - self._prefetch_on_device = prefetch_on_device self._auto_shard_dataset = auto_shard_dataset # Remember num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: @@ -402,7 +391,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # TODO(josh11b): Require at least 2 devices? 
self._devices = [device_util.resolve(d) for d in devices] self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)}) + self._device_index = values.PerReplica( + {d: i for i, d in enumerate(devices)}) def _initialize_multi_worker(self, num_gpus, cluster_spec): """Initializes the object for multi-worker training.""" @@ -417,19 +407,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if num_gpus is None: raise ValueError("`num_gpus` is required if `cluster_spec` is given.") if num_gpus > 0: - self._worker_device_map = { - worker: [ + self._worker_devices = [ + (worker, [ device_util.canonicalize(worker + "/device:GPU:%d" % gpu) for gpu in range(num_gpus) - ] for worker in self._workers - } + ]) for worker in self._workers + ] else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, "/device:CPU:0")] + self._worker_devices = [ + (worker, [device_util.canonicalize(worker, "/device:CPU:0")]) for worker in self._workers - } + ] - devices = nest.flatten(self._worker_device_map) + devices = nest.flatten([l for _, l in self._worker_devices]) # Setting `_default_device` will add a device scope in the # distribution.scope. We set the default device to the first worker. When @@ -446,7 +436,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # TODO(josh11b): Require at least 2 devices? 
self._devices = [device_util.resolve(d) for d in devices] self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice( + self._device_index = values.PerReplica( {d: i for i, d in enumerate(devices)}) def _create_variable(self, next_creator, *args, **kwargs): @@ -490,12 +480,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def distribute_dataset(self, dataset_fn): if self._cluster_spec: return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device, self._auto_shard_dataset) + partial(self._call_dataset_fn, dataset_fn), self._worker_devices, + auto_shard=self._auto_shard_dataset) else: - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device) + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), self._devices) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _run_steps_on_dataset(self, fn, iterator, iterations, @@ -546,10 +535,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access output = last_step_tensor_outputs_dict[name] # For outputs that have already been aggregated, wrap them in a Mirrored - # container, else in a PerDevice container. + # container, else in a PerReplica container. 
if aggregation is variables_lib.VariableAggregation.NONE: last_step_tensor_outputs_dict[name] = values.regroup( - {d: t for d, t in zip(self._devices, output)}, values.PerDevice) + {d: t for d, t in zip(self._devices, output)}, values.PerReplica) else: assert len(output) == 1 last_step_tensor_outputs_dict[name] = output[0] @@ -562,23 +551,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._get_cross_tower_ops().broadcast(tensor, destinations or self._devices) - def _call_for_each_replica(self, fn, *args, **kwargs): - return _call_for_each_replica(self, fn, *args, **kwargs) - - def map(self, map_over, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - index = {} - for i, m in enumerate(map_over): - d = self._devices[i % len(self._devices)] - with ops.device(d): - l = index.get(d, []) - l.append(fn(m, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs))) - index[d] = l - # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput - # in addition to PerDevice data. - return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + def _call_for_each_replica(self, fn, args, kwargs): + return _call_for_each_replica(self, fn, args, kwargs) def configure(self, session_config=None, @@ -617,9 +591,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerDevice or Mirrored - # values. For example, the same value could be present on all replicas in - # which case `value` would be a single value or value could be 0. + # This function handles reducing values that are not PerReplica or + # Mirrored values. For example, the same value could be present on all + # replicas in which case `value` would be a single value or value could + # be 0. 
return _reduce_non_distributed_value(self, aggregation, value, destinations) if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_REPLICA: @@ -818,7 +793,7 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext): `MirroredStrategy.call_for_each_replica()`). """ - def _merge_call(self, fn, *args, **kwargs): + def _merge_call(self, fn, args, kwargs): """Delegate to the main thread to actually perform merge_call().""" t = threading.current_thread() # a _MirroredReplicaThread t.merge_fn = fn @@ -837,5 +812,9 @@ class MirroredReplicaContext(distribute_lib.ReplicaContext): @property def device(self): + raise RuntimeError("Use .devices instead") + + @property + def devices(self): distribute_lib.require_replica_context(self) - return self._distribution_strategy.worker_devices[self._replica_id] + return [self._distribution_strategy.worker_devices[self._replica_id]] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index b8e7edaaf82804e6741cc4c94c44ed77189d7ad9..1fd18e09c01d9da8a3c2b865f526cb113d40f530 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -78,11 +78,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self._test_minimize_loss_graph( self._get_distribution_strategy(), soft_placement=soft_placement) - def testMapReduce(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_map_reduce(self._get_distribution_strategy()) - def testDeviceIndex(self): if not GPU_TEST: self.skipTest("Not GPU test") @@ -120,7 +115,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(), self.assertRaises(AssertionError): - dist.call_for_each_replica(run_fn, dist.worker_device_index) + 
dist.call_for_each_replica(run_fn, args=(dist.worker_device_index,)) @test_util.run_in_graph_and_eager_modes def testReduceToCpu(self): @@ -132,7 +127,8 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): - result = dist.call_for_each_replica(run_fn, dist.worker_device_index) + result = dist.call_for_each_replica( + run_fn, args=(dist.worker_device_index,)) reduced = dist.reduce( variable_scope.VariableAggregation.SUM, result, @@ -152,7 +148,8 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): - result = dist.call_for_each_replica(run_fn, dist.worker_device_index) + result = dist.call_for_each_replica( + run_fn, args=(dist.worker_device_index,)) reduced = dist.reduce( variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, result, @@ -207,7 +204,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) self.assertEquals("foo:0", result.name) @@ -225,7 +222,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # Default name of "Variable" will be used. 
self.assertEquals("Variable:0", result.name) @@ -246,7 +243,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) for i, v in enumerate(result): self.assertIsInstance(v, values.MirroredVariable) self.assertEquals("foo" + str(i) + ":0", v.name) @@ -269,7 +266,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) for v in result: self.assertIsInstance(v, values.MirroredVariable) self.assertEquals(4, len(result)) @@ -293,7 +290,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with dist.scope(): result = dist.call_for_each_replica( - model_fn, dist.worker_device_index, run_concurrently=False) + model_fn, args=(dist.worker_device_index,)) self.assertIsInstance(result, values.MirroredVariable) # The resulting mirrored variable will use the name from the first device. 
self.assertEquals("foo_0:0", result.name) @@ -329,8 +326,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): features = iterator.get_next() with dist.scope(): - result = dist.call_for_each_replica( - model_fn, features, run_concurrently=False) + result = dist.call_for_each_replica(model_fn, args=(features,)) suffixes = ["", "_1", "_2"] for (kernel, bias), suffix in zip(result, suffixes): self.assertIsInstance(kernel, values.MirroredVariable) @@ -368,7 +364,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v = variable_scope.variable(1.0, name="var-main0") self.assertEquals("var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) @@ -411,7 +407,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v = variable_scope.get_variable("var-main0", [1]) self.assertEquals("main/var-main0:0", v.name) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) @@ -448,7 +444,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): devices = ["/device:GPU:0", "/device:CPU:0"] dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): - v0, v1 = dist.call_for_each_replica(create_fn, run_concurrently=False) + v0, v1 = dist.call_for_each_replica(create_fn) self.evaluate(v0.initializer) self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) @@ -465,7 +461,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return update0, update1 update0a, update1a = dist.call_for_each_replica( - update_member_fn, dist.worker_device_index, run_concurrently=False) + update_member_fn, args=(dist.worker_device_index,)) # 
Update "sync on read" variable. self.evaluate(dist.group(update0a)) @@ -491,7 +487,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return update0, update1 update0b, update1b = dist.call_for_each_replica( - update_state_ops_fn, dist.worker_device_index, run_concurrently=False) + update_state_ops_fn, args=(dist.worker_device_index,)) self.evaluate(dist.group(update0b)) # Update "sync on read" variable. @@ -588,7 +584,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) with dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) self.assertEquals("foo:0", result.name) @@ -611,7 +607,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): "/device:GPU:0": "bar" }) with self.assertRaises(RuntimeError): - _ = dist.call_for_each_replica(model_fn, names, run_concurrently=False) + _ = dist.call_for_each_replica(model_fn, args=(names,)) @test_util.run_in_graph_and_eager_modes(config=config) def testReplicaLocalVariable(self): @@ -652,7 +648,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Create "sum" and "mean" versions of ReplicaLocalVariables. ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( dist.call_for_each_replica( - model_fn, dist.worker_device_index, run_concurrently=False)) + model_fn, args=(dist.worker_device_index,))) # Should see the same wrapping instance in all replicas. 
self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) @@ -709,7 +705,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(), dist.scope(): with ops.name_scope("main"): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) @@ -730,7 +726,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) self.assertEquals(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) @@ -760,7 +756,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(), dist.scope(): with ops.name_scope("main"): a = variable_scope.variable(1.0, name="a") - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) @@ -793,7 +789,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(), dist.scope(): with ops.name_scope("main"): a = variable_scope.get_variable("a", [1]) - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) @@ -824,7 +820,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with context.graph_mode(), dist.scope(): - result = dist.call_for_each_replica(model_fn, run_concurrently=False) + result = dist.call_for_each_replica(model_fn) # Two variables are 
created by the RNN layer. self.assertEquals(2, len(result)) for v in result: @@ -851,7 +847,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return var.assign(value) with dist.scope(): - ret_v_sum = dist.call_for_each_replica(model_fn, run_concurrently=False) + ret_v_sum = dist.call_for_each_replica(model_fn) update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) # Initialize variables. @@ -894,7 +890,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -908,7 +904,7 @@ class MirroredVariableUpdateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes(config=config) def testAssignMirroredVarReplicaContextWithSum(self): - # Test that we don't reduce a non-per-device value with the "sum" + # Test that we don't reduce a non-per-replica value with the "sum" # aggregation type. 
self._skip_eager_if_gpus_less_than(1) def var_fn(): @@ -920,7 +916,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -942,7 +938,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -960,7 +956,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -971,8 +967,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(0.5, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -986,7 +981,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) 
self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -994,8 +989,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(5.0, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1008,7 +1002,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -1036,7 +1030,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -1047,8 +1041,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(1.5, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1062,7 +1055,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) 
self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -1070,8 +1063,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(6.0, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1084,7 +1076,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -1104,7 +1096,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -1115,8 +1107,7 @@ class MirroredVariableUpdateTest(test.TestCase): mirrored_var.dtype) return mirrored_var.assign_sub(value) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(4.5, self.evaluate(mirrored_var)) @test_util.run_in_graph_and_eager_modes(config=config) @@ -1130,7 +1121,7 @@ class MirroredVariableUpdateTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - mirrored_var = dist.call_for_each_replica(var_fn, 
run_concurrently=False) + mirrored_var = dist.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -1138,8 +1129,7 @@ class MirroredVariableUpdateTest(test.TestCase): def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(dist.unwrap(dist.call_for_each_replica( - model_fn, run_concurrently=False))) + self.evaluate(dist.unwrap(dist.call_for_each_replica(model_fn))) self.assertEquals(4.0, self.evaluate(mirrored_var)) @@ -1211,8 +1201,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn, - run_concurrently=False) + replica_local_var = dist.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) @@ -1243,8 +1232,7 @@ class ReplicaLocalVariableAssignTest(test.TestCase): ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): - replica_local_var = dist.call_for_each_replica(model_fn, - run_concurrently=False) + replica_local_var = dist.call_for_each_replica(model_fn) self.assertTrue(isinstance(replica_local_var, values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) @@ -1307,8 +1295,7 @@ class MirroredStrategyDefunTest(test.TestCase): mock_model = MockModel(two_variables) self.evaluate(variables.global_variables_initializer()) - result = dist.call_for_each_replica(model_fn, mock_model, *inputs, - run_concurrently=False) + result = dist.call_for_each_replica(model_fn, args=[mock_model] + inputs) for device in devices: device_result = values.select_device(device, result) device_expected_result = values.select_device(device, expected_result) @@ -1320,11 +1307,10 @@ class MirroredStrategyDefunTest(test.TestCase): # call_for_each has one trace per device. 
To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. - per_device_graph_functions = dist.call_for_each_replica( - defun.get_concrete_function, - mock_model, *inputs, run_concurrently=False) + per_replica_graph_functions = dist.call_for_each_replica( + defun.get_concrete_function, args=[mock_model] + inputs) for device in devices: - graph_function = per_device_graph_functions.get(device=device) + graph_function = per_replica_graph_functions.get(device=device) self.assertEqual(set(mock_model.variables), set(graph_function.graph.variables)) @@ -1398,16 +1384,16 @@ class MirroredStrategyDefunTest(test.TestCase): two_variables=True) @test_util.run_in_graph_and_eager_modes() - def testPassPerDevice(self): + def testPassPerReplica(self): self._skip_eager_if_gpus_less_than(1) @function.defun def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) + factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) self._call_and_check(fn1, [factors], expected_result, [fn1]) @test_util.run_in_graph_and_eager_modes() @@ -1429,8 +1415,7 @@ class MirroredStrategyDefunTest(test.TestCase): gradients_fn = backprop.implicit_grad(loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = dist.call_for_each_replica( - gradients_fn, None, run_concurrently=False) + grads_and_vars = dist.call_for_each_replica(gradients_fn, args=(None,)) optimizer = gradient_descent.GradientDescentOptimizer(0.25) update_ops = optimizer._distributed_apply(dist, grads_and_vars) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 
2bfe0f3e7a66311c9b0673761b73382e477cb24b..bea684e77ca554af8937b72619470550ad4666b1 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -40,9 +40,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - def testDeviceIndex(self): self._test_device_index(self._get_distribution_strategy()) @@ -83,7 +80,8 @@ class VariableCreatorStackTest(test.TestCase): with context.graph_mode(), \ dist.scope(), \ variable_scope.variable_creator_scope(main_thread_creator): - result = dist.call_for_each_replica(model_fn, dist.worker_device_index) + result = dist.call_for_each_replica( + model_fn, args=(dist.worker_device_index,)) result = dist.unwrap(result) expected = ["main_thread:thread_0", "main_thread:thread_1"] self.assertEquals(expected, result) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py index 815644421e36cc397d6faebf9abd9c54bab557de..7ecc852d20508cc7063f3598c9fef03d6ce536a5 100644 --- a/tensorflow/contrib/distribute/python/moving_averages_test.py +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -93,7 +93,8 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var = variables.Variable([10.0, 11.0]) val = constant_op.constant([1.0, 2.0]) decay = 0.25 - # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. 
assign = moving_averages.assign_moving_average( var, val, decay, zero_debias=False) @@ -121,7 +122,8 @@ class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): var = variables.Variable([0.0, 0.0]) val = array_ops.placeholder(dtypes.float32) decay = 0.25 - # NOTE(josh11b): We currently generate an error if val is a PerDevice value. + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. assign = moving_averages.assign_moving_average(var, val, decay) variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 8bdf0012087a60cd7d4acfd4eaf0ee0742275655..a0d8f938874bf4098a15b8bb8507c20a56618324 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -25,8 +25,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -40,10 +38,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # doing something that won't work with other DistributionStrategy # implementations? 
- def __init__(self, device, prefetch_on_device=None): + def __init__(self, device): super(OneDeviceStrategy, self).__init__() self._device = device - self._prefetch_on_device = prefetch_on_device self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): @@ -62,9 +59,8 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return next_creator(*args, **kwargs) def distribute_dataset(self, dataset_fn): - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), [self._device], - self._prefetch_on_device) + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), [self._device]) def _broadcast(self, tensor, destinations): del destinations @@ -117,29 +113,13 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access return ctx - def _call_for_each_replica(self, fn, *args, **kwargs): - # We don't run `fn` in multiple threads in OneDeviceStrategy. 
- kwargs.pop("run_concurrently", None) + def _call_for_each_replica(self, fn, args, kwargs): with ops.device(self._device), _OneDeviceReplicaContext(self): return fn(*args, **kwargs) - def map(self, map_over, fn, *args, **kwargs): - with ops.device(self._device): - return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - def _reduce(self, aggregation, value, destinations): - del destinations - if not isinstance(value, values.MapOutput): - return value - l = value.get() - assert l - with ops.device(self._device): - if aggregation == vs.VariableAggregation.SUM: - return math_ops.add_n(l) - elif aggregation == vs.VariableAggregation.MEAN: - return math_ops.add_n(l) / len(l) - else: - assert False + del aggregation, destinations + return value def _update(self, var, options, fn, *args, **kwargs): # The implementations of _update() and _update_non_slot() are identical @@ -171,6 +151,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def num_replicas(self): return 1 + @property + def num_replicas_in_sync(self): + return 1 + @property def worker_devices(self): return [self._device] @@ -188,6 +172,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): + """ReplicaContext for OneDeviceStrategy.""" def __init__(self, distribution_strategy): distribute_lib.ReplicaContext.__init__( @@ -195,4 +180,8 @@ class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): @property def device(self): - return self._distribution_strategy.worker_devices[0] + raise RuntimeError("Use .devices instead") + + @property + def devices(self): + return [self._distribution_strategy.worker_devices[0]] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 3fb92273924a665bf2a1ee5fc94b75273b8c5f78..95f4cdb7868a09e696d60d4e0ab6cb4592682787 100644 --- 
a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -35,9 +35,6 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - def testDeviceIndex(self): self._test_device_index(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 0554f4a83bda28142d709020a5a648127d66eab0..fa4705af7cb592119f56686d1f693a156f7b4b13 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def run_step(): return control_flow_ops.group(distribution.unwrap( distribution.call_for_each_replica( - model_fn, iterator.get_next(), run_concurrently=layer.built))) + model_fn, args=(iterator.get_next(),)))) if not context.executing_eagerly(): with self.cached_session() as sess: diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 2aa7f1ae5d6f0b3735570c1a5c11ba8c4ce662af..790b37f86010eba6bdc87e6424e55a97629c5d1a 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -64,7 +64,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): Operations that occur only on the first replica (such as incrementing the global step), will occur on the first replica *of every worker*. 
- It is expected to call `call_for_each_replica(fn, *args, **kwargs)` for any + It is expected to call `call_for_each_replica(fn, ...)` for any operations which potentially can be replicated across replicas (i.e. multiple GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra caution needs to be taken: @@ -223,7 +223,7 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" - return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._compute_devices, True) def _broadcast(self, tensor, destinations): @@ -231,10 +231,13 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): destinations = self._compute_devices return self._cross_tower_ops.broadcast(tensor, destinations) + def _allow_variable_partition(self): + return not context.executing_eagerly() + # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through # this creator, such as "MutableHashTable". 
def _create_variable(self, next_creator, *args, **kwargs): - if self.num_replicas > 1: + if self.num_replicas_in_sync > 1: aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in ( vs.VariableAggregation.NONE, @@ -288,9 +291,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): with ops.device(self._variable_device): return var_creator(*args, **kwargs) - def _call_for_each_replica(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access - return mirrored_strategy._call_for_each_replica(self, fn, *args, **kwargs) + return mirrored_strategy._call_for_each_replica(self, fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): if not self._cluster_spec: @@ -336,9 +339,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "You cannot update variable with a Mirrored object with multiple " "components %r when using ParameterServerStrategy. You must " "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerDevice): + elif isinstance(x, values.PerReplica): raise ValueError( - "You cannot update variable with a PerDevice object %r when using " + "You cannot update variable with a PerReplica object %r when using " "ParameterServerStrategy. 
You must specify a single value or a " "Mirrored with a single value" % x) else: diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index a9f643c6eccdaee86b7ea531a82d257884fa92b9..81a23c89030221a8a15bdedc796c50d9c518138c 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -37,6 +37,7 @@ from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -85,8 +86,7 @@ class ParameterServerStrategyTestBase( config=sess_config) as sess, \ d.scope(): - # Define a variable outside the call_for_each_replica scope. This is not - # recommended. + # Define a variable outside the call_for_each_replica scope. 
n = variable_scope.get_variable('n', initializer=10.0) self.assertEqual(n.device, '/job:ps/task:0') @@ -178,6 +178,75 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) + def _test_device_assignment_distributed_enable_partitioner( + self, task_type, task_id, num_gpus): + d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + num_shards = len(d.parameter_devices) + partitioner = partitioned_variables.fixed_size_partitioner(num_shards) + with ops.Graph().as_default(), \ + self.cached_session(target=self._default_target, + config=sess_config) as sess, \ + d.scope(): + + n = variable_scope.get_variable( + 'n', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + + for part_id, var in enumerate(n): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + + def model_fn(): + a = constant_op.constant([3.0, 5.0]) + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable( + 'x', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + x_add = x.assign_add(a, name='x_add') + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + for part_id, var in enumerate(x): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + self.assertEqual(var.device, x_add[part_id].device) + + # The colocate_vars_with can override the distribution's device. 
+ with d.colocate_vars_with(x_add[0]): + y = variable_scope.get_variable( + 'y', + initializer=constant_op.constant([20.0, 10.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + y_add = y.assign_add( + [array_ops.identity(x_add[0]), + array_ops.identity(x_add[1])]) + + for part_id, var in enumerate(y): + self.assertEqual(var.device, '/job:ps/task:0') + self.assertEqual(y_add[part_id].device, var.device) + self.assertEqual(var.device, x_add[0].device) + + return x_add, y_add + + x, y = d.call_for_each_replica(model_fn) + + if context.num_gpus() >= 1: + variables.global_variables_initializer().run() + x_val, y_val = sess.run([x, y]) + if num_gpus < 1: + self.assertEqual(x_val, [13.0, 25.0]) + self.assertEqual(y_val, [33.0, 35.0]) + else: + x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] + y_expect = [ + 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus + ] + self.assertEqual(x_val, x_expect) + self.assertEqual(y_val, y_expect) + def _test_device_assignment_local(self, d, compute_device='CPU', @@ -345,11 +414,11 @@ class ParameterServerStrategyTestBase( self._finish_condition.release() x_val, y_val, z_val = sess.run([x, y, z]) - self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas) - self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync) self.assertEqual(z_val, 30.0 + 1.0 * num_workers) - return (x_val == 10.0 + 1.0 * num_workers * d.num_replicas and - y_val == 20.0 + 1.0 * num_workers * d.num_replicas and + return (x_val == 10.0 + 1.0 * num_workers * d.num_replicas_in_sync and + y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): @@ -394,7 +463,7 @@ class ParameterServerStrategyTestBase( def step(): """Perform one optimization 
step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -479,6 +548,12 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, def testDeviceAssignmentDistributed(self, num_gpus): self._test_device_assignment_distributed('worker', 1, num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + self._test_device_assignment_distributed_enable_partitioner( + 'worker', 1, num_gpus) + def testSimpleBetweenGraph(self): self._run_between_graph_clients(self._test_simple_increment, self._cluster_spec, context.num_gpus()) diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index a5adaac47ceb3e22909bb852c6e3418446710a51..3dc815f0371002bd3a8657f18ccc09a27bb14961 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -90,7 +90,6 @@ class StandardSingleLossStep(StandardInputStep): super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer - self._is_run_concurrently = False self._iterations_per_step = iterations_per_step def __call__(self): @@ -101,14 +100,11 @@ class StandardSingleLossStep(StandardInputStep): gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) grads_and_vars = self.distribution.call_for_each_replica( - gradients_fn, - ctx, *inputs, - run_concurrently=self._is_run_concurrently) + gradients_fn, args=(ctx,) + inputs) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. # Otherwise, multiple sets of mirrored variables are going to be # created. 
- self._is_run_concurrently = True return self._optimizer._distributed_apply( # pylint: disable=protected-access self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 60ef0a2106a17d2eede6acec9b9178d1a9d736ff..3c0c10430ebc445d1e2303d83ea28f9cf731fb71 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -104,7 +104,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one, run_concurrently=l.built) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] @@ -160,7 +160,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_replica(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] @@ -189,15 +189,6 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_map_reduce(self, d, in_graph=None): - with d.scope(): - map_in = [constant_op.constant(i) for i in range(10)] - map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out, - "/device:CPU:0") - expected = 90 # 2 * (0 + 1 + ... 
+ 9) - self.assertEqual(expected, observed.numpy()) - def _test_device_index(self, d): with d.scope(): expected_devices = [False] * len(d.worker_devices) @@ -207,7 +198,7 @@ class DistributionTestBase(test.TestCase): self.assertFalse(expected_devices[device_id]) expected_devices[device_id] = True - d.call_for_each_replica(mark_devices_fn, d.worker_device_index) + d.call_for_each_replica(mark_devices_fn, args=(d.worker_device_index,)) self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) def _test_replica_id(self, d): diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 65ef21df09ba34c274cdce73996bff7b9c32da85..f5b4531ba8c483e69f2a2b5539b27205efb9fc21 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -141,7 +141,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # parallelism. device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) if "device:TPU:" in d.name} - self._device_index = values.PerDevice(device_map) + self._device_index = values.PerReplica(device_map) self._host_device = self.get_host_cpu_device(0) self._tpu_devices = sorted(device_map.keys()) # Only create variables for the number of replicas we're running. @@ -215,12 +215,12 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return enqueue_op_per_host def distribute_dataset(self, dataset_fn): - worker_map = { - self.get_host(hid): [self.get_host_cpu_device(hid)] + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) for hid in range(self.num_hosts) - } + ] return values.MultiWorkerDataset( - functools.partial(self._call_dataset_fn, dataset_fn), worker_map) + functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. 
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have @@ -308,7 +308,8 @@ class TPUStrategy(distribute_lib.DistributionStrategy): # For outputs that have already been aggregated, take the first value # from the list as each value should be the same. Else return the full # list of values. - # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. + # TODO(josh11b): If aggregation is NONE, we should return a PerReplica + # value. if aggregation is not variables_lib.VariableAggregation.NONE: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] @@ -316,10 +317,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return ctx - def _call_for_each_replica(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # TODO(jhseu): Consider making it so call_for_each_replica implies that # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. - kwargs.pop("run_concurrently", None) with _TPUReplicaContext(self): return fn(*args, **kwargs) @@ -445,7 +445,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should - # be represented using a PerDevice wrapper instead of a list with + # be represented using a PerReplica wrapper instead of a list with # one entry per device. 
return val return [val] @@ -544,5 +544,9 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): @property def device(self): + raise RuntimeError("Use .devices instead") + + @property + def devices(self): distribute_lib.require_replica_context(self) - return self._distribution_strategy.worker_devices[self._replica_id] + return [self._distribution_strategy.worker_devices[self._replica_id]] diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 42fb92014a08001d9ed2b6833dac6b1b4efad434..a1629735353d87f9dcba6553399003056c84ed86 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -51,7 +51,7 @@ from tensorflow.python.util import nest # TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): - """Holds a map from device to values. Either PerDevice or Mirrored.""" + """Holds a map from device to values. Either PerReplica or Mirrored.""" def __init__(self, index): self._index = {device_util.canonicalize(key): value @@ -62,7 +62,8 @@ class DistributedValues(object): if device is None: replica_context = distribution_strategy_context.get_replica_context() if replica_context: - device = replica_context.device + # TODO(josh11b): support model parallelism better here + device = replica_context.devices[0] else: device = distribute_lib.get_update_device() if device is None: @@ -75,10 +76,6 @@ class DistributedValues(object): ValueError("Device %s not found in %s (current device %s)" % (device, self._index.keys(), device_util.current())), e) - def on_device(self, device): - device = device_util.canonicalize(device) - return device in self._index - @property def devices(self): return list(self._index.keys()) @@ -167,12 +164,12 @@ class DistributedDelegate(DistributedValues): # TODO(josh11b): Even more operator overloads. 
-class PerDevice(DistributedValues): +class PerReplica(DistributedValues): """Holds a map from device to unsynchronized values.""" pass -# Note that unlike PerDevice, Mirrored values inherit from +# Note that unlike PerReplica, Mirrored values inherit from # DistributedDelegate and so can be used directly in cross-replica mode. class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" @@ -482,7 +479,8 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): if device is None: replica_context = distribution_strategy_context.get_replica_context() if replica_context: - device = replica_context.device + # TODO(josh11b): support model parallelism better here + device = replica_context.devices[0] else: device = distribute_lib.get_update_device() if device is None: @@ -583,7 +581,8 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): # update_non_slot() function (like OptimizerV2._finish), which can # update several non-slot variables in one call. 
def _assign_func(self, *args, **kwargs): - if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy": + strategy = distribution_strategy_context.get_distribution_strategy() + if strategy.__class__.__name__ != "TPUStrategy": raise ValueError("You may only assign to a TPUMirroredVariable within a " "TPUStrategy.") f = kwargs.pop("f") @@ -776,6 +775,18 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase): def op(self): return self._primary_var.op + # pylint: disable=protected-access + @property + def _save_slice_info(self): + return self._primary_var._save_slice_info + + def _get_save_slice_info(self): + return self._primary_var._get_save_slice_info() + + def _set_save_slice_info(self, save_slice_info): + return self._primary_var._set_save_slice_info(save_slice_info) + # pylint: enable=protected-access + @property def _in_graph_mode(self): return self._primary_var._in_graph_mode # pylint: disable=protected-access @@ -861,7 +872,7 @@ def _assert_replica_context(): "Replica-local variables may only be assigned in a replica context.") -class ReplicaLocalVariable(DistributedVariable, PerDevice, +class ReplicaLocalVariable(DistributedVariable, PerReplica, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -942,9 +953,9 @@ def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) -def regroup(per_device, wrap_class=PerDevice): - """Makes device->nest map into a nest of PerDevice/Mirrored values.""" - items = list(per_device.items()) +def regroup(per_replica, wrap_class=PerReplica): + """Makes device->nest map into a nest of PerReplica/Mirrored values.""" + items = list(per_replica.items()) assert items v0 = items[0][1] # First value @@ -1005,7 +1016,7 @@ def regroup(per_device, wrap_class=PerDevice): # want to return the containing MirroredVariable, after a bunch of # sanity checking. 
In particular, each component should have the # same container, and the devices of the variables should match the - # keys of the per-device dictionary. + # keys of the per-replica dictionary. if hasattr(v0, "_distributed_container"): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( @@ -1021,11 +1032,11 @@ def regroup(per_device, wrap_class=PerDevice): return distributed_container # pylint: enable=protected-access - return wrap_class(per_device) + return wrap_class(per_replica) def select_device(device, structured): - """Specialize a nest of regular & per-device values for one device.""" + """Specialize a nest of regular & per-replica values for one device.""" def _get(x): return x.get(device) if isinstance(x, DistributedValues) else x @@ -1070,8 +1081,8 @@ def update_regroup(strategy, updates, should_group): return nest.pack_sequence_as(regrouped, grouped_flat) -class PerDeviceDataIterator(object): - """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" +class PerReplicaDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerReplicaDataset`.""" def __init__(self, iterator, devices, prefetch_on_device=None): self._iterator = iterator @@ -1114,8 +1125,8 @@ class PerDeviceDataIterator(object): return self._iterator.output_types -class PerDeviceDataset(object): - """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" +class PerReplicaDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerReplica` data.""" def __init__(self, dataset, devices, prefetch_on_device=None): self._devices = devices @@ -1136,20 +1147,20 @@ class PerDeviceDataset(object): self._dataset = dataset.batch(len(devices), drop_remainder=True) def make_one_shot_iterator(self): - """Get a one time use iterator for the distributed PerDeviceDataset.""" + """Get a one time use iterator for the distributed PerReplicaDataset.""" # Graph mode with one shot iterator is disabled. 
if not context.executing_eagerly(): raise ValueError("Cannot create a one shot iterator. Please use " "`make_initializable_iterator()` instead.") # Eager mode prefetching would error out in constructor. Only remaining # case is non-prefetching in eager mode. We delegate to - # PerDeviceDataIterator to handle that case. + # PerReplicaDataIterator to handle that case. dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator( + return PerReplicaDataIterator( dataset_iterator, self._devices, prefetch_on_device=False) def make_initializable_iterator(self): - """Get an initializable iterator for the distributed PerDeviceDataset.""" + """Get an initializable iterator for the distributed PerReplicaDataset.""" # Eager mode generates already initialized iterators. Hence we cannot create # an initializable iterator. if context.executing_eagerly(): @@ -1160,7 +1171,7 @@ class PerDeviceDataset(object): self._dataset, self._devices) else: dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator( + return PerReplicaDataIterator( dataset_iterator, self._devices, prefetch_on_device=self._prefetch_on_device) @@ -1169,43 +1180,47 @@ class PerDeviceDataset(object): class MultiWorkerDataIterator(object): """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" - def __init__(self, iterators, worker_device_map): + def __init__(self, iterators, worker_device_pairs): """Initialize the MultiWorkerDataIterator object. Args: - iterators: a dict mapping from each worker to an iterator for - that worker. - worker_device_map: a dict mapping from each worker's devices to a list of - devices that belong to this worker. + iterators: a list of worker, iterator pairs. + worker_device_pairs: a list of (worker's devices, a list of + devices that belong to this worker) pairs. Raises: - ValueError: if iterators and worker_device_map are not compatible. 
+ ValueError: if iterators and worker_device_pairs are not compatible. """ - self._iterators = iterators - self._worker_device_map = worker_device_map - if set(self._iterators) != set(self._worker_device_map): - raise ValueError("iterators and worker_device_map are not compatible.") + if [d for d, _ in iterators] != [d for d, _ in worker_device_pairs]: + raise ValueError("iterators and worker_device_pairs are not compatible.") + self._workers = [d for d, _ in iterators] + self._iterators = [i for _, i in iterators] + self._worker_devices = [l for _, l in worker_device_pairs] @property def initializer(self): return control_flow_ops.group( - [iterator.initializer for iterator in self._iterators.values()]) + [iterator.initializer for iterator in self._iterators]) def get_iterator(self, worker): - return self._iterators.get(worker) + for i, w in enumerate(self._workers): + if worker == w: + return self._iterators[i] + return None @property def output_shapes(self): - return self._iterators.values()[0].output_shapes + return self._iterators[0].output_shapes @property def output_types(self): - return self._iterators.values()[0].output_types + return self._iterators[0].output_types def get_next(self, name=None): """Scatter the input across hosts and devices.""" index = {} - for worker, iterator in six.iteritems(self._iterators): + worker_info = zip(self._workers, self._iterators, self._worker_devices) + for worker, iterator, worker_devices in worker_info: if name is not None: d = tf_device.DeviceSpec.from_string(worker) new_name = "%s_%s_%d" % (name, d.job, d.task) @@ -1214,13 +1229,12 @@ class MultiWorkerDataIterator(object): with ops.device(worker): data_per_worker = iterator.get_next(name=new_name) - worker_devices = self._worker_device_map[worker] - # Ungroup these per-device value so as to get a flat map from devices to + # Ungroup these per-replica value so as to get a flat map from devices to # values. 
for d in worker_devices: v = select_device(d, data_per_worker) if d in index: - raise ValueError("Duplicated devices in worker_device_map: %r" % v) + raise ValueError("Duplicated devices in worker_device_pairs: %r" % v) index[d] = v return regroup(index) @@ -1229,153 +1243,48 @@ class MultiWorkerDataIterator(object): class MultiWorkerDataset(object): """Like a `tf.data.Dataset` that distributes data to different workers. - Each worker gets one shard of the input dataset. It is currently not working - in - eager mode. + Each worker gets one shard of the input dataset. This currently does not work + in eager mode. """ - def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None, + def __init__(self, dataset_fn, worker_device_pairs, prefetch_on_device=None, auto_shard=False): """Initialize the MultiWorkerDataset object. Args: dataset_fn: a function that returns a `tf.data.Dataset`. - worker_device_map: a dict mapping from each worker to a list of devices - that belong to this worker. + worker_device_pairs: a list of (worker, list of devices on that worker) + pairs. prefetch_on_device: whether to prefetch to devices. auto_shard: whether to auto-shard the dataset. """ - self._worker_device_map = worker_device_map - self._datasets = {} + self._worker_device_pairs = worker_device_pairs + self._datasets = [] # TODO(yuefengz, priyag): support different set of jobs for input # processing. 
- for i, (worker, worker_devices) in enumerate( - six.iteritems(worker_device_map)): + for i, (worker, worker_devices) in enumerate(worker_device_pairs): with ops.device(worker): worker_input = dataset_fn() if auto_shard: worker_input = input_ops.auto_shard_dataset( - worker_input, len(worker_device_map), i) - self._datasets[worker] = PerDeviceDataset( + worker_input, len(worker_device_pairs), i) + dataset = PerReplicaDataset( worker_input, worker_devices, prefetch_on_device=prefetch_on_device) + self._datasets.append((worker, dataset)) def make_one_shot_iterator(self): - iterators = {} - for worker, dataset in six.iteritems(self._datasets): + iterators = [] + for worker, dataset in self._datasets: with ops.device(worker): - iterators[worker] = dataset.make_one_shot_iterator() - return MultiWorkerDataIterator(iterators, self._worker_device_map) + iterators.append((worker, dataset.make_one_shot_iterator())) + return MultiWorkerDataIterator(iterators, self._worker_device_pairs) def make_initializable_iterator(self): - iterators = {} - for worker, dataset in six.iteritems(self._datasets): + iterators = [] + for worker, dataset in self._datasets: with ops.device(worker): - iterators[worker] = dataset.make_initializable_iterator() - return MultiWorkerDataIterator(iterators, self._worker_device_map) - - -class _PerKey(object): - """Holds data associated by keys.""" - - def __init__(self, *index): - # pylint: disable=protected-access - self._index = list(index) - - def get(self, iteration): - return array_ops.gather(self._index, iteration) - - def get_shape(self): - return self._index[-1][-1].get_shape() - - def get_dtype(self): - return self._index[-1][-1].dtype - - def __str__(self): - return "%s:%s" % (self.__class__.__name__, self._index) - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._index) - - -class PerIteration(_PerKey): - """Holds input for multiple iterations at once.""" - - def __init__(self, *index): - # pylint: 
disable=protected-access - super(PerIteration, self).__init__(*[batch._index for batch in index]) - - -class Batches(_PerKey): - pass - - -class MultiIterator(object): - """Iterator that returns results of multiple get_next()s.""" - - def __init__(self, dataset_iterator, iterations, batches_per_iteration): - self._dataset_iterator = dataset_iterator - self._iterations = iterations - self._batches_per_iteration = batches_per_iteration - - def get_next(self, name=None): - """Return PerIteration with `iterations x batches_per_iteration` inputs.""" - data = [] - for _ in range(self._batches_per_iteration): - batch = [] - for _ in range(self._iterations): - batch.append(self._dataset_iterator.get_next(name=name)) - data.append(batch) - - # Here is an example. Suppose each get_next returns a tuple of two tensors. - # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: - # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] - # - # After the first `map_structure` it gets transformed to: - # [(Batches(a, A), Batches(z, Z)), - # (Batches(b, B), Batches(y, Y)), - # (Batches(c, C), Batches(x, X))] - # - # After the second `map_structure` it gets transformed to a tuple of: - # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), - # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) - - data = nest.map_structure(Batches, *data) - data = nest.map_structure(PerIteration, *data) - - return data - - @property - def initializer(self): - return self._dataset_iterator.initializer - - -class PerIterationDataset(object): - """A dataset that returns MultiIterators.""" - - def __init__(self, dataset, iterations, batches_per_iteration): - self._dataset = dataset - self._iterations = iterations - self._batches_per_iteration = batches_per_iteration - - def make_one_shot_iterator(self): - iterator = self._dataset.make_one_shot_iterator() - return MultiIterator(iterator, self._iterations, - self._batches_per_iteration) - - def make_initializable_iterator(self): - 
iterator = self._dataset.make_initializable_iterator() - return MultiIterator(iterator, self._iterations, - self._batches_per_iteration) - - -class MapOutput(object): - """Map can result in multiple outputs per device.""" - - def __init__(self, l): - self._l = l - - def get(self): - return self._l + iterators.append((worker, dataset.make_initializable_iterator())) + return MultiWorkerDataIterator(iterators, self._worker_device_pairs) class MultiStepContext(object): @@ -1430,13 +1339,13 @@ class MultiStepContext(object): output: The tensors that should be outputted with `name`. See below for actual types supported. aggregation: Aggregation method to use to aggregate outputs from multiple - replicas. Required if `set_last_step_output` is called in a replica context. - Optional in cross_replica_context. + replicas. Required if `set_last_step_output` is called in a replica + context. Optional in cross_replica_context. When present, the outputs from all the replicas are aggregated using the current distribution strategy's `reduce` method. Hence, the type of `output` must be what's supported by the corresponding `reduce` method. For e.g. if using MirroredStrategy and aggregation is set, output - must be a `PerDevice` value. + must be a `PerReplica` value. The aggregation method is also recorded in a dictionary `_last_step_outputs_aggregations` for later interpreting of the outputs as already reduced or not. @@ -1482,7 +1391,7 @@ class MultiStepContext(object): def value_container(val): - """Returns the container that this per-device `value` belongs to. + """Returns the container that this per-replica `value` belongs to. Args: val: A value returned by `call_for_each_replica()` or a variable @@ -1528,8 +1437,8 @@ class AggregatingVariable(checkpointable.CheckpointableBase): # We are calling an assign function in an update context. return f(self._v, *args, **kwargs) - # We are calling an assign function in cross replica context, wrap it in an - # update call. 
+ # We are calling an assign function in cross replica context, wrap it in + # an update call. return distribution_strategy_context.get_distribution_strategy().update( self, f, *args, **kwargs) else: diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index d514e6f4c158d15665a2cd46be0547178da66544..268393ee801b5f25bb5a7f061960b817c2d2ce5e 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import os from tensorflow.contrib.distribute.python import mirrored_strategy @@ -190,10 +189,10 @@ def _make_mirrored(): class RegroupAndSelectDeviceTest(test.TestCase): - def _is_per_device(self, result, expected, klass=values.PerDevice): + def _is_per_replica(self, result, expected, klass=values.PerReplica): self.assertIsInstance(result, klass) # We canonicalize the devices to match the device strings returned - # by PerDevice, which also does device string canonicalization. + # by PerReplica, which also does device string canonicalization. 
devices = [device_util.canonicalize(_device_str(i)) for i in range(len(expected))] self.assertEqual(set(devices), set(result.devices)) @@ -206,18 +205,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): _nested_value("2")}) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"]) - self._is_per_device(result[2], ["h1", "h2"]) + self._is_per_replica(result[0], ["a1", "a2"]) + self._is_per_replica(result[2], ["h1", "h2"]) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"]) - self._is_per_device(result[1][2], ["g1", "g2"]) + self._is_per_replica(result[1][0], ["b1", "b2"]) + self._is_per_replica(result[1][2], ["g1", "g2"]) self.assertIsInstance(result[1][1], dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"]) - self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -238,18 +237,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) - self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) - self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored) self.assertIsInstance(result[1][1], 
dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) - self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -275,7 +274,7 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): ("b", foo)}) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) - self._is_per_device(result[0], ["a", "b"]) + self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) # Test select_device(), should undo the merge done by regroup(). @@ -341,53 +340,30 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -class PerDeviceDatasetTest(test.TestCase): +class PerReplicaDatasetTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True - def _test_iterator_no_prefetch(self, devices, dataset, expected_values): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) + def _test_iterator(self, devices, dataset, expected_values): + per_replica_dataset = values.PerReplicaDataset(dataset, devices) if context.executing_eagerly(): - iterator = per_device_dataset.make_one_shot_iterator() + iterator = per_replica_dataset.make_one_shot_iterator() else: - iterator = per_device_dataset.make_initializable_iterator() + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate([iterator.initializer]) for expected_value in expected_values: next_element = iterator.get_next() - actual = self.evaluate([ - values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, actual) + computed_value = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + 
self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() self.evaluate([ values.select_device(d, next_element) for d in devices]) - def _test_iterator_with_prefetch(self, devices, dataset, expected_values): - if not context.executing_eagerly(): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=True) - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) - - def _test_iterator(self, devices, dataset, expected_values): - self._test_iterator_no_prefetch(devices, dataset, expected_values) - self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] @@ -442,9 +418,8 @@ class PerDeviceDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices( random_ops.random_uniform((10,))) - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) - iterator = per_device_dataset.make_initializable_iterator() + per_replica_dataset = values.PerReplicaDataset(dataset, devices) + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate(iterator.initializer) next_element = iterator.get_next() @@ -463,7 +438,7 @@ class PerDeviceDatasetTest(test.TestCase): class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - def _test_iterator(self, iterator, devices, expected_values): + def _test_iterator(self, sess, iterator, devices, expected_values): next_element = 
iterator.get_next() for device in devices: v = values.select_device(device, next_element) @@ -472,73 +447,79 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): self.assertTrue(element.device in device) for expected_value in expected_values: - actual = self.evaluate( + actual = sess.run( [values.select_device(d, next_element) for d in devices]) self.assertEqual(expected_value, actual) with self.assertRaises(errors.OutOfRangeError): - self.evaluate([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_device(d, next_element) for d in devices]) - def _test_dataset(self, dataset_fn, worker_device_map, devices, - expected_values): + def _test_dataset(self, dataset_fn, worker_devices, devices, + expected_values, auto_shard=True): multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) - multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() - self._test_iterator(multi_worker_iterator, devices, expected_values) + dataset_fn, worker_devices, auto_shard=auto_shard) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.cached_session() as sess: + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, expected_values) def _cpu_devices(self): - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])]) + worker_devices = [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] devices = [ "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def _cpu_and_one_gpu_devices(self): - # The 
worker_device_map doesn't have to be a OrderDict object, this is just - # to simplify the testing so that we can pass expected values as a list - # instead of a dict. - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", [ + worker_devices = [ + ("/job:worker/replica:0/task:0", [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0" - ]), ("/job:worker/replica:0/task:1", [ + ]), + ("/job:worker/replica:0/task:1", [ "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" - ])]) + ]) + ] devices = [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def testDataDistributionOneDevicePerWorker(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) + def testDataDistributionNoAutoShard(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_dataset(dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], + auto_shard=False) + def testDataDistributionTwoDevicePerWorker(self): - self.skipTest("Temporarily disabled.") if context.num_gpus() < 1: self.skipTest("A GPU is not available for this test.") - worker_device_map, devices = self._cpu_and_one_gpu_devices() + worker_devices, devices = self._cpu_and_one_gpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + 
self._test_dataset(dataset_fn, worker_devices, devices, [[0, 2, 1, 3], [4, 6, 5, 7]]) def testTupleDataset(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): @@ -550,41 +531,38 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): expected_values = [ [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) ] - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, expected_values) def testInitializableIterator(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() - with context.graph_mode(): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) # After re-initializing the iterator, should be able to iterate again. - self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) def testValueErrorForIterator(self): - self.skipTest("Temporarily disabled.") # Incompatiable arguments. with self.assertRaises(ValueError): values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) # Test duplicated devices under same worker. 
- worker_device_map, _ = self._cpu_devices() - worker_device_map["/job:worker/replica:0/task:0"].append( - "/job:worker/replica:0/task:0/device:CPU:0") + worker_devices, _ = self._cpu_devices() + worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() with self.assertRaises(ValueError): multi_worker_iterator.get_next() diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py index 6e775afb69af04331186fbd3ce963e0b286e14f4..67ffb939663358b5e356b3b626978db959c1bac9 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_v2.py @@ -544,4 +544,19 @@ class SequenceNumericColumn( return fc.SequenceDenseColumn.TensorSequenceLengthPair( dense_tensor=dense_tensor, sequence_length=seq_length) + # TODO(b/119409767): Implement parents, _{get,from}_config. 
+ @property + def parents(self): + """See 'FeatureColumn` base class.""" + raise NotImplementedError() + + def _get_config(self): + """See 'FeatureColumn` base class.""" + raise NotImplementedError() + + @classmethod + def _from_config(cls, config, custom_objects=None, columns_by_name=None): + """See 'FeatureColumn` base class.""" + raise NotImplementedError() + # pylint: enable=protected-access diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index cfdc7df7d8fd4c1406bf447a79038ac33b11e047..00e04b83ac45a83e54eee7a6e4e146fb683c3d98 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -44,17 +44,18 @@ class ExperimentalTest(test.TestCase): # Assert function docs are properly updated. self.assertEqual("_fn", _fn.__name__) - self.assertEqual("fn doc. (experimental)" - "\n" - "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " - "be removed at any time, and without warning." - "\n" - "\nArgs:" - "\n arg0: Arg 0." - "\n arg1: Arg 1." - "\n" - "\nReturns:" - "\n Sum of args.", _fn.__doc__) + self.assertEqual( + "fn doc. (experimental)" + "\n" + "\nWarning: THIS FUNCTION IS EXPERIMENTAL. It may change " + "or be removed at any time, and without warning." + "\n" + "\nArgs:" + "\n arg0: Arg 0." + "\n arg1: Arg 1." + "\n" + "\nReturns:" + "\n Sum of args.", _fn.__doc__) # Assert calling new fn issues log warning. 
self.assertEqual(3, _fn(1, 2)) diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index d8584e4e6b7470472b0e1911b9e34c8c80a42e0f..b3f48ec1dd9c75055f4e1ea76eb203b6ccf94718 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -52,9 +52,10 @@ Status GdrServer::Init() { [this](const WorkerEnv* env) { return new GdrRendezvousMgr(env, remote_memory_manager_.get()); }; - WorkerCreationFunction worker_func = [this](WorkerEnv* env) { + WorkerCreationFunction worker_func = [this](WorkerEnv* env, + const ConfigProto& config) { return std::unique_ptr( - new GdrWorker(env, remote_memory_manager_.get())); + new GdrWorker(env, config, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR(remote_memory_manager_->Init()); diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index ce1d8d2d73000559f03046aceacb169890ecc1b6..867cb83f42034c8e9061e333ea671457745f92c3 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -39,9 +39,9 @@ limitations under the License. 
namespace tensorflow { -GdrWorker::GdrWorker(WorkerEnv* worker_env, +GdrWorker::GdrWorker(WorkerEnv* worker_env, const ConfigProto& config, RemoteMemoryManager* remote_memory_manager) - : GrpcWorker(worker_env), + : GrpcWorker(worker_env, config), remote_memory_manager_(remote_memory_manager), recv_tensor_recent_request_ids_(100000) {} diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index 65105ed997300aa77202301cdd8dddacb0309880..39f11e6bde5a1ca7ae91ead02279d22d70af027b 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -25,7 +25,8 @@ namespace tensorflow { class GdrWorker : public GrpcWorker { public: - GdrWorker(WorkerEnv* env, RemoteMemoryManager* remote_memory_manager); + GdrWorker(WorkerEnv* env, const ConfigProto& config, + RemoteMemoryManager* remote_memory_manager); // Serve the RecvTensorRequest but omit the tensor content and transmit it // out-of-band using GPU Direct RDMA whenever possible. diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD index 9393b702d11a2ef84586f712d30c26fe2a8972bb..2698b83a56a1121fa30f5b05ffa027b4dfd4ba95 100644 --- a/tensorflow/contrib/ignite/BUILD +++ b/tensorflow/contrib/ignite/BUILD @@ -22,48 +22,92 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", + ":igfs_ops", ], ) tf_custom_op_library( - name = "_dataset_ops.so", - srcs = ["ops/dataset_ops.cc"], - deps = [":dataset_kernels"], + name = "_ignite_ops.so", + srcs = [ + "kernels/igfs/igfs.h", + "ops/dataset_ops.cc", + "ops/igfs_ops.cc", + ], + deps = [ + ":dataset_kernels", + ":igfs_kernels", + ], ) tf_gen_op_libs( op_lib_names = ["dataset_ops"], ) +tf_gen_op_libs( + op_lib_names = ["igfs_ops"], + deps = [":igfs_kernels"], +) + cc_library( - name = "dataset_kernels", + name = "ignite_client", srcs = [ - "kernels/ignite_dataset_ops.cc", - "kernels/ignite_client.h", - "kernels/ignite_byte_swapper.h", - "kernels/ignite_plain_client.h", - 
"kernels/ignite_ssl_wrapper.h", - "kernels/ignite_ssl_wrapper.cc", - "kernels/ignite_binary_object_parser.h", - "kernels/ignite_binary_object_parser.cc", - "kernels/ignite_dataset.h", - "kernels/ignite_dataset.cc", - "kernels/ignite_dataset_iterator.h", - "kernels/ignite_dataset_iterator.cc", + "kernels/client/ignite_client.h", + "kernels/client/ignite_byte_swapper.h", + "kernels/client/ignite_plain_client.h", + "kernels/client/ignite_ssl_wrapper.h", + "kernels/client/ignite_ssl_wrapper.cc", ] + if_not_windows([ - "kernels/ignite_plain_client_unix.cc", + "kernels/client/ignite_plain_client_unix.cc", ]) + if_windows([ - "kernels/ignite_plain_client_windows.cc", + "kernels/client/ignite_plain_client_windows.cc", ]), copts = if_windows([ "-DWIN32_LEAN_AND_MEAN", ]), deps = [ "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", "@boringssl//:ssl", "@protobuf_archive//:protobuf_headers", ], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/dataset/ignite_binary_object_parser.cc", + "kernels/dataset/ignite_binary_object_parser.h", + "kernels/dataset/ignite_dataset.cc", + "kernels/dataset/ignite_dataset.h", + "kernels/dataset/ignite_dataset_iterator.cc", + "kernels/dataset/ignite_dataset_iterator.h", + "kernels/dataset/ignite_dataset_ops.cc", + ], + deps = [ + ":ignite_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "igfs_kernels", + srcs = [ + "kernels/igfs/igfs.cc", + "kernels/igfs/igfs.h", + "kernels/igfs/igfs_client.cc", + "kernels/igfs/igfs_client.h", + "kernels/igfs/igfs_extended_tcp_client.cc", + "kernels/igfs/igfs_extended_tcp_client.h", + "kernels/igfs/igfs_messages.cc", + "kernels/igfs/igfs_messages.h", + "kernels/igfs/igfs_random_access_file.cc", + "kernels/igfs/igfs_random_access_file.h", + "kernels/igfs/igfs_writable_file.cc", + "kernels/igfs/igfs_writable_file.h", + ], + deps = [":ignite_client"], 
alwayslink = 1, ) @@ -82,10 +126,29 @@ py_library( ], ) +py_library( + name = "igfs_ops", + srcs = [ + "python/ops/igfs_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":igfs_op_loader", + "//tensorflow/python:util", + "//tensorflow/python/data/util:nest", + ], +) + tf_gen_op_wrapper_py( name = "gen_dataset_ops", out = "python/ops/gen_dataset_ops.py", - deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"], + deps = [":dataset_ops_op_lib"], +) + +tf_gen_op_wrapper_py( + name = "gen_igfs_ops", + out = "python/ops/gen_igfs_ops.py", + deps = [":igfs_ops_op_lib"], ) tf_kernel_library( @@ -97,13 +160,22 @@ tf_kernel_library( alwayslink = 1, ) +tf_kernel_library( + name = "igfs_ops_kernels", + deps = [ + ":igfs_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + tf_custom_op_py_library( name = "ignite_op_loader", srcs = ["python/ops/ignite_op_loader.py"], - dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"], + dso = [":_ignite_ops.so"], kernels = [ ":dataset_ops_kernels", - "//tensorflow/contrib/ignite:dataset_ops_op_lib", + ":dataset_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ @@ -113,6 +185,22 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "igfs_op_loader", + srcs = ["python/ops/igfs_op_loader.py"], + dso = [":_ignite_ops.so"], + kernels = [ + ":igfs_ops_kernels", + ":igfs_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_igfs_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + # The Apache Ignite servers have to setup before the test and tear down # after the test manually. The docker engine has to be installed. 
# @@ -122,8 +210,11 @@ tf_custom_op_py_library( # To tear down Apache Ignite servers: # $ bash ./python/tests/stop_ignite.sh tf_py_test( - name = "ignite_dataset_test", - srcs = ["python/tests/ignite_dataset_test.py"], + name = "ignite_test", + srcs = [ + "python/tests/igfs_test.py", + "python/tests/ignite_dataset_test.py", + ], additional_deps = [ ":ignite", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md index 55c89d27996318dabb29bb15372411005301ebd9..c7db0b77e25668fb8a42d204776044420f403e44 100644 --- a/tensorflow/contrib/ignite/README.md +++ b/tensorflow/contrib/ignite/README.md @@ -1,19 +1,32 @@ -# Ignite Dataset - -- [Overview](#overview) -- [Features](#features) - * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) - * [Structured Objects](#structured-objects) - * [Distributed Training](#distributed-training) - * [SSL Connection](#ssl-connection) - * [Windows Support](#windows-support) -- [Try it out](#try-it-out) -- [Limitations](#limitations) +# Apache Ignite Integration + +- [Overview](#overview) +- [Features](#features) + * [Distributed In-Memory Datasource](#distributed-in-memory-datasource) + * [Structured Objects](#structured-objects) + * [Distributed Training](#distributed-training) + * [Distributed File System](#distributed-file-system) + * [SSL Connection](#ssl-connection) + * [Windows Support](#windows-support) +- [Try it out](#try-it-out) + * [Ignite Dataset](#ignite-dataset) + * [IGFS](#igfs) +- [Limitations](#limitations) ## Overview -[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for -transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. 
The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow. +[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed +database, caching, and processing platform for transactional, analytical, and +streaming workloads, delivering in-memory speeds at petabyte scale. This contrib +package contains an integration between Apache Ignite and TensorFlow. The +integration is based on +[tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow +side and +[Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) +from Apache Ignite side. It allows to use Apache Ignite as a data source for +neural network training, inference and all other computations supported by +TensorFlow. Another part of this module is an integration with distributed file +system based on Apache Ignite. ## Features @@ -134,6 +147,23 @@ Ignite Dataset allows using these two aspects of distributed neural network trai High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well. +### Distributed File System + +In addition to database functionality Apache Ignite provides a distributed file +system called [IGFS](https://ignite.apache.org/features/igfs.html). IGFS +delivers a similar functionality to Hadoop HDFS, but only in-memory. In fact, in +addition to its own APIs, IGFS implements Hadoop FileSystem API and can be +transparently plugged into Hadoop or Spark deployments. This contrib package +contains an integration between IGFS and TensorFlow. 
The integration is based +on [custom filesystem plugin](https://www.tensorflow.org/extend/add_filesys) +from TensorFlow side and +[IGFS Native API](https://ignite.apache.org/features/igfs.html) from Apache +Ignite side. It has numerous uses, for example: * Checkpoints of state can be +saved to IGFS for reliability and fault-tolerance. * Training processes +communicate with TensorBoard by writing event files to a directory, which +TensorBoard watches. IGFS allows this communication to work even when +TensorBoard runs in a different process or machine. + ### SSL Connection Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation. @@ -141,9 +171,12 @@ Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikip ```python >>> import tensorflow as tf >>> from tensorflow.contrib.ignite import IgniteDataset ->>> ->>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite") ->>> ... +>>> +>>> dataset = IgniteDataset(cache_name="IMAGES", + certfile="client.pem", + cert_password="password", + username="ignite", + password="ignite") ``` ### Windows Support @@ -152,7 +185,16 @@ Ignite Dataset is fully compatible with Windows. You can use it as part of Tenso ## Try it out -The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). 
You need to start this container on your machine: +The following examples will help you to easily start working with this module. + +### Ignite Dataset + +The simplest way to try Ignite Dataset is to run a +[Docker](https://www.docker.com/) container with Apache Ignite and loaded +[MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interact with +it using Ignite Dataset. Such container is available on Docker Hub: +[dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). +You need to start this container on your machine: ``` docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist @@ -162,6 +204,35 @@ After that you will be able to work with it following way: ![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist") +### IGFS + +The simplest way to try IGFS with TensorFlow is to run +[Docker](https://www.docker.com/) container with Apache Ignite and enabled IGFS +and then interact with it using TensorFlow +[tf.gfile](https://www.tensorflow.org/api_docs/python/tf/gfile). Such container +is available on Docker Hub: +[dmitrievanthony/ignite-with-igfs](https://hub.docker.com/r/dmitrievanthony/ignite-with-igfs/). +You need to start this container on your machine: + +``` +docker run -it -p 10500:10500 dmitrievanthony/ignite-with-igfs +``` + +After that you will be able to work with it following way: + +```python +>>> import tensorflow as tf +>>> import tensorflow.contrib.ignite.python.ops.igfs_ops +>>> +>>> with tf.gfile.Open("igfs:///hello.txt", mode='w') as w: +>>> w.write("Hello, world!") +>>> +>>> with tf.gfile.Open("igfs:///hello.txt", mode='r') as r: +>>> print(r.read()) + +Hello, world! +``` + ## Limitations Presently, Ignite Dataset works with assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object.
Another limitation concerns structured objects, Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure. diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h similarity index 67% rename from tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h rename to tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h index 46df3e39dc4ec6dd4ef5730a184264eaa9fc5872..aac950fcc2aaf016959bbda876ac93df4baea417 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ #include #include "tensorflow/core/platform/byte_order.h" @@ -25,76 +25,75 @@ class ByteSwapper { public: ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; } - inline void SwapIfRequiredInt16(int16_t *x) const { + void SwapIfRequiredInt16(int16_t *x) const { if (swap_) { Swap16(x); } } - inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const { + void SwapIfRequiredUnsignedInt16(uint16_t *x) const { if (swap_) { Swap16(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt32(int32_t *x) const { + void SwapIfRequiredInt32(int32_t *x) const { if (swap_) { Swap32(x); } } - inline void SwapIfRequiredFloat(float *x) const { + void SwapIfRequiredFloat(float *x) const { if (swap_) { Swap32(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt64(int64_t *x) const { + void SwapIfRequiredInt64(int64_t *x) const 
{ if (swap_) { Swap64(x); } } - inline void SwapIfRequiredDouble(double *x) const { + void SwapIfRequiredDouble(double *x) const { if (swap_) { Swap64(reinterpret_cast(x)); } } - inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { + void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap16(&x[i]); } } - inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, - int32_t length) const { + void SwapIfRequiredUnsignedInt16Arr(uint16_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap16(reinterpret_cast(&x[i])); } } - inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { + void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap32(&x[i]); } } - inline void SwapIfRequiredFloatArr(float *x, int32_t length) const { + void SwapIfRequiredFloatArr(float *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap32(reinterpret_cast(&x[i])); } } - inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { + void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap64(&x[i]); } } - inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const { + void SwapIfRequiredDoubleArr(double *x, int32_t length) const { if (swap_) { for (int32_t i = 0; i < length; i++) Swap64(reinterpret_cast(&x[i])); @@ -102,16 +101,16 @@ class ByteSwapper { } private: - inline void Swap16(int16_t *x) const { + void Swap16(int16_t *x) const { *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF); } - inline void Swap32(int32_t *x) const { + void Swap32(int32_t *x) const { *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) | (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF); } - inline void Swap64(int64_t *x) const { + void Swap64(int64_t *x) const { *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) | (((*x >> 16) & 
0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) | (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) | @@ -123,4 +122,4 @@ class ByteSwapper { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_BYTE_SWAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/client/ignite_client.h similarity index 74% rename from tensorflow/contrib/ignite/kernels/ignite_client.h rename to tensorflow/contrib/ignite/kernels/client/ignite_client.h index 459b50b48fd95ad105bccaca4076160e0ef152ee..0da80769260d065c4ac6601c0e5cd7050b6b61cb 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_client.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_client.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -32,44 +32,44 @@ class Client { virtual Status ReadData(uint8_t *buf, const int32_t length) = 0; virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0; - inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); } + Status ReadByte(uint8_t *data) { return ReadData(data, 1); } - inline Status ReadShort(int16_t *data) { + Status ReadShort(int16_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2)); byte_swapper_.SwapIfRequiredInt16(data); return Status::OK(); } - inline 
Status ReadInt(int32_t *data) { + Status ReadInt(int32_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4)); byte_swapper_.SwapIfRequiredInt32(data); return Status::OK(); } - inline Status ReadLong(int64_t *data) { + Status ReadLong(int64_t *data) { TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8)); byte_swapper_.SwapIfRequiredInt64(data); return Status::OK(); } - inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } + Status WriteByte(const uint8_t data) { return WriteData(&data, 1); } - inline Status WriteShort(const int16_t data) { + Status WriteShort(const int16_t data) { int16_t tmp = data; byte_swapper_.SwapIfRequiredInt16(&tmp); return WriteData((uint8_t *)&tmp, 2); } - inline Status WriteInt(const int32_t data) { + Status WriteInt(const int32_t data) { int32_t tmp = data; byte_swapper_.SwapIfRequiredInt32(&tmp); return WriteData((uint8_t *)&tmp, 4); } - inline Status WriteLong(const int64_t data) { + Status WriteLong(const int64_t data) { int64_t tmp = data; byte_swapper_.SwapIfRequiredInt64(&tmp); return WriteData((uint8_t *)&tmp, 8); @@ -81,4 +81,4 @@ class Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h similarity index 80% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client.h rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h index 75424c19ee4b7df5378aa23cb41db1752e8d0651..546583246042855d179ebbb18b7dca711063b3f4 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" namespace tensorflow { @@ -40,4 +40,4 @@ class PlainClient : public Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_PLAIN_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc index cf672942c61e1239332711db12e62088737c4f41..54efb5b61761708a28dd031b8321ffba9a53ffa9 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_unix.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" #include #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc index dad5aace5fabe1df58bb9579bf578f4c35324315..a99a3ada558e51c13ed47eb72911eb5862e71a60 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_plain_client_windows.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" #define WIN32_LEAN_AND_MEAN #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc rename to tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc index ceb479b0846574a35d86002ebb9c3e8e1d3687ac..8f09c24a3bedda524264f30282a0ad019d515540 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc +++ b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h" #include #include diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h similarity index 82% rename from tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h rename to tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h index 0406644bbaab3de816540ce85e84b489ea9fff12..543e03d1efc3ff186c9db399af18f7aa8ad2c450 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h +++ b/tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" #include @@ -48,4 +48,4 @@ class SslWrapper : public Client { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_CLIENT_IGNITE_SSL_WRAPPER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc similarity index 99% rename from tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc index 
2c8a7d44b07b43f788bcbc0850b5162cc14dd951..4218ec05f2c3486dd91e2188b674e01d6aadaa2b 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h similarity index 87% rename from tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h index eb1f856643a790de6acaa82d4b8ad894fd364376..3e8a1a19623fab3e027db16228e0228e8ec4989a 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ #include -#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_byte_swapper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -78,4 +78,4 @@ enum ObjectType { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_BINARY_OBJECT_PARSER_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_dataset.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc index c4a7d3c513a796c9d95b371bedc609fd75188817..ace96e7b09fcf314757367baed66f622b294e43c 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h similarity index 91% rename from tensorflow/contrib/ignite/kernels/ignite_dataset.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h index 66bfdf2e2a168e59cd2fec8e2ac5b8fd482d5c15..db3bafb11f2a0047c22ece6d2bc1722afaa5ffdf 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ #include "tensorflow/core/framework/dataset.h" @@ -60,4 +60,4 @@ class IgniteDataset : public DatasetBase { } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc similarity index 98% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc index 5da9127aa6a3a4bc16347e6890cc1ba44406c0d5..ce8972f1e7fd59235556cb9514011f0b836077de 100644 --- 
a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h" -#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h" -#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_ssl_wrapper.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h similarity index 87% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h index c499e2c9ccfac5c15db08c8fd8b26c37aa0404f3..5868c2cb67f9d5c91654db8cf4bb4bbc072fc1ac 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ -#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" -#include "tensorflow/contrib/ignite/kernels/ignite_client.h" -#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/contrib/ignite/kernels/client/ignite_client.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { @@ -96,4 +96,4 @@ constexpr int32_t kMinResLength = 12; } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_ +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_DATASET_IGNITE_DATASET_ITERATOR_H_ diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc similarity index 97% rename from tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc rename to tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc index f75b1c5ff55ca9ee493148ff79c2edd4b15ac42a..f2108775e29b53765138dcd971bec89d7a10ce40 100644 --- a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include -#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h" -#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" +#include "tensorflow/contrib/ignite/kernels/dataset/ignite_dataset.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/strings/numbers.h" diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae2dbcc2cf5d0ae7e09a26a199dc0c3c80fe22c1 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs.cc @@ -0,0 +1,331 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/file_system_helper.h" + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h" + +namespace tensorflow { + +static string GetEnvOrElse(const string &env, string default_value) { + const char *env_c_str = env.c_str(); + return getenv(env_c_str) != nullptr ? getenv(env_c_str) : default_value; +} + +static string MakeRelative(const string &a, const string &b) { + string max = a; + string min = b; + bool first = b.size() > a.size(); + + if (first) { + max = b; + min = a; + } + + auto r = mismatch(min.begin(), min.end(), max.begin()); + return string((first ? r.first : r.second), first ? 
min.end() : max.end()); +} + +string IGFS::TranslateName(const string &name) const { + StringPiece scheme, namenode, path; + io::ParseURI(name, &scheme, &namenode, &path); + return string(path.data(), path.length()); +} + +IGFS::IGFS() + : host_(GetEnvOrElse("IGFS_HOST", "localhost")), + port_([] { + int port; + if (strings::safe_strto32(GetEnvOrElse("IGFS_PORT", "10500").c_str(), + &port)) { + return port; + } else { + LOG(WARNING) + << "IGFS_PORT environment variable had an invalid value: " + << getenv("IGFS_PORT") << "\nUsing default port 10500."; + return 10500; + } + }()), + fs_name_(GetEnvOrElse("IGFS_FS_NAME", "default_fs")) { + LOG(INFO) << "IGFS created [host=" << host_ << ", port=" << port_ + << ", fs_name=" << fs_name_ << "]"; +} + +IGFS::~IGFS() { + LOG(INFO) << "IGFS destroyed [host=" << host_ << ", port=" << port_ + << ", fs_name=" << fs_name_ << "]"; +} + +Status IGFS::NewRandomAccessFile(const string &file_name, + std::unique_ptr *result) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse open_read_response(true); + TF_RETURN_IF_ERROR(client->OpenRead(&open_read_response, path)); + + int64 resource_id = open_read_response.res.stream_id; + result->reset(new IGFSRandomAccessFile(path, resource_id, std::move(client))); + + LOG(INFO) << "New random access file completed successfully [file_name=" + << file_name << "]"; + + return Status::OK(); +} + +Status IGFS::NewWritableFile(const string &file_name, + std::unique_ptr *result) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse exists_response(false); + TF_RETURN_IF_ERROR(client->Exists(&exists_response, path)); + + if (exists_response.res.exists) { + CtrlResponse del_response(false); + 
TF_RETURN_IF_ERROR(client->Delete(&del_response, path, false)); + } + + CtrlResponse open_create_resp(false); + TF_RETURN_IF_ERROR(client->OpenCreate(&open_create_resp, path)); + + int64 resource_id = open_create_resp.res.stream_id; + result->reset(new IGFSWritableFile(path, resource_id, std::move(client))); + + LOG(INFO) << "New writable file completed successfully [file_name=" + << file_name << "]"; + + return Status::OK(); +} + +Status IGFS::NewAppendableFile(const string &file_name, + std::unique_ptr *result) { + std::unique_ptr client = CreateClient(); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse exists_response(false); + TF_RETURN_IF_ERROR(client->Exists(&exists_response, file_name)); + + if (exists_response.res.exists) { + CtrlResponse del_response(false); + TF_RETURN_IF_ERROR(client->Delete(&del_response, file_name, false)); + } + + CtrlResponse open_append_resp(false); + TF_RETURN_IF_ERROR(client->OpenAppend(&open_append_resp, file_name)); + + result->reset(new IGFSWritableFile(TranslateName(file_name), + open_append_resp.res.stream_id, + std::move(client))); + + LOG(INFO) << "New appendable file completed successfully [file_name=" + << file_name << "]"; + + return Status::OK(); +} + +Status IGFS::NewReadOnlyMemoryRegionFromFile( + const string &file_name, std::unique_ptr *result) { + return errors::Unimplemented("IGFS does not support ReadOnlyMemoryRegion"); +} + +Status IGFS::FileExists(const string &file_name) { + std::unique_ptr client = CreateClient(); + const string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse exists_response(false); + TF_RETURN_IF_ERROR(client->Exists(&exists_response, path)); + + if (!exists_response.res.exists) + return errors::NotFound("File ", path, " not found"); + + LOG(INFO) << "File exists completed successfully [file_name=" << file_name + 
<< "]"; + + return Status::OK(); +} + +Status IGFS::GetChildren(const string &file_name, std::vector *result) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + path = path + "/"; + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse list_paths_response(false); + TF_RETURN_IF_ERROR(client->ListPaths(&list_paths_response, path)); + + *result = std::vector(); + std::vector entries = list_paths_response.res.entries; + + for (IGFSPath &value : entries) + result->push_back(MakeRelative(value.path, path)); + + LOG(INFO) << "Get children completed successfully [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::GetMatchingPaths(const string &pattern, + std::vector *results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + +Status IGFS::DeleteFile(const string &file_name) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse del_response(false); + TF_RETURN_IF_ERROR(client->Delete(&del_response, path, false)); + + if (!del_response.res.exists) + return errors::NotFound("File ", path, " not found"); + + LOG(INFO) << "Delete file completed successfully [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::CreateDir(const string &file_name) { + std::unique_ptr client = CreateClient(); + const string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse mkdir_response(false); + TF_RETURN_IF_ERROR(client->MkDir(&mkdir_response, path)); + + if (!mkdir_response.res.successful) + return errors::Unknown("Can't create directory ", path); + + LOG(INFO) << "Create dir completed successful [file_name=" << file_name + << "]"; + + return 
Status::OK(); +} + +Status IGFS::DeleteDir(const string &file_name) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse list_files_response(false); + TF_RETURN_IF_ERROR(client->ListFiles(&list_files_response, path)); + + if (!list_files_response.res.entries.empty()) { + return errors::FailedPrecondition("Can't delete a non-empty directory"); + } else { + CtrlResponse del_response(false); + TF_RETURN_IF_ERROR(client->Delete(&del_response, path, true)); + } + + LOG(INFO) << "Delete dir completed successful [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::GetFileSize(const string &file_name, uint64 *size) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse info_response(false); + TF_RETURN_IF_ERROR(client->Info(&info_response, path)); + + *size = info_response.res.file_info.length; + + LOG(INFO) << "Get file size completed successful [file_name=" << file_name + << "]"; + + return Status::OK(); +} + +Status IGFS::RenameFile(const string &src, const string &dst) { + std::unique_ptr client = CreateClient(); + string src_path = TranslateName(src); + string dst_path = TranslateName(dst); + + if (FileExists(dst).ok()) DeleteFile(dst); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse rename_response(false); + TF_RETURN_IF_ERROR(client->Rename(&rename_response, src_path, dst_path)); + + if (!rename_response.res.successful) + return errors::NotFound("File ", src_path, " not found"); + + LOG(INFO) << "Rename file completed successful [src=" << src + << ", dst=" << dst << "]"; + + return Status::OK(); +} + +Status IGFS::Stat(const string &file_name, FileStatistics 
*stats) { + std::unique_ptr client = CreateClient(); + string path = TranslateName(file_name); + + CtrlResponse handshake_response(true); + TF_RETURN_IF_ERROR(client->Handshake(&handshake_response)); + + CtrlResponse info_response(false); + TF_RETURN_IF_ERROR(client->Info(&info_response, path)); + + IGFSFile info = info_response.res.file_info; + + *stats = FileStatistics(info.length, info.modification_time * 1000000, + (info.flags & 0x1) != 0); + + LOG(INFO) << "Stat completed successful [file_name=" << file_name << "]"; + + return Status::OK(); +} + +std::unique_ptr IGFS::CreateClient() const { + return std::unique_ptr( + new IGFSClient(host_, port_, fs_name_, "")); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs.h b/tensorflow/contrib/ignite/kernels/igfs/igfs.h new file mode 100644 index 0000000000000000000000000000000000000000..4c347e937f75e8eea108811e6a3189412e22a982 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFS : public FileSystem { + public: + IGFS(); + ~IGFS(); + Status NewRandomAccessFile( + const string& file_name, + std::unique_ptr* result) override; + Status NewWritableFile(const string& fname, + std::unique_ptr* result) override; + Status NewAppendableFile(const string& fname, + std::unique_ptr* result) override; + Status NewReadOnlyMemoryRegionFromFile( + const string& fname, + std::unique_ptr* result) override; + Status FileExists(const string& fname) override; + Status GetChildren(const string& dir, std::vector* result) override; + Status GetMatchingPaths(const string& pattern, + std::vector* results) override; + Status DeleteFile(const string& fname) override; + Status CreateDir(const string& name) override; + Status DeleteDir(const string& name) override; + Status GetFileSize(const string& fname, uint64* size) override; + Status RenameFile(const string& src, const string& target) override; + Status Stat(const string& fname, FileStatistics* stat) override; + string TranslateName(const string& name) const override; + + private: + std::unique_ptr CreateClient() const; + + const string host_; + const int port_; + const string fs_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f97c34fdd8b026a04506fd0ef9f3cc74129a9da --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" + +namespace tensorflow { + +IGFSClient::IGFSClient(const string &host, int port, const string &fs_name, + const string &user_name) + : fs_name_(fs_name), + user_name_(user_name), + client_(ExtendedTCPClient(host, port, true)) { + client_.Connect(); +} + +IGFSClient::~IGFSClient() { client_.Disconnect(); } + +Status IGFSClient::SendRequestGetResponse(const Request &request, + Response *response) { + TF_RETURN_IF_ERROR(request.Write(&client_)); + client_.reset(); + + if (response != nullptr) { + TF_RETURN_IF_ERROR(response->Read(&client_)); + client_.reset(); + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h new file mode 100644 index 0000000000000000000000000000000000000000..bbec7b000779be8772e850a556affffa1b3b6803 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_client.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +class IGFSClient { + public: + IGFSClient(const string &host, int port, const string &fs_name, + const string &user_name); + ~IGFSClient(); + + Status Handshake(CtrlResponse *res) { + return SendRequestGetResponse(HandshakeRequest(fs_name_, {}), res); + } + + Status ListFiles(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(ListFilesRequest(user_name_, path), res); + } + + Status ListPaths(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(ListPathsRequest(user_name_, path), res); + } + + Status Info(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(InfoRequest(user_name_, path), res); + } + + Status OpenCreate(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(OpenCreateRequest(user_name_, path), res); + } + + Status OpenAppend(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(OpenAppendRequest(user_name_, path), res); + } + + Status OpenRead(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(OpenReadRequest(user_name_, path), res); + } + + Status Exists(CtrlResponse *res, const string &path) { + return SendRequestGetResponse(ExistsRequest(user_name_, path), res); + } + + Status MkDir(CtrlResponse *res, const string &path) { + return 
SendRequestGetResponse(MakeDirectoriesRequest(user_name_, path), + res); + } + + Status Delete(CtrlResponse *res, const string &path, + bool recursive) { + return SendRequestGetResponse(DeleteRequest(user_name_, path, recursive), + res); + } + + Status WriteBlock(int64_t stream_id, const uint8_t *data, int32_t len) { + return SendRequestGetResponse(WriteBlockRequest(stream_id, data, len), + nullptr); + } + + Status ReadBlock(ReadBlockCtrlResponse *res, int64_t stream_id, int64_t pos, + int32_t length) { + return SendRequestGetResponse(ReadBlockRequest(stream_id, pos, length), + res); + } + + Status Close(CtrlResponse *res, int64_t stream_id) { + return SendRequestGetResponse(CloseRequest(stream_id), res); + } + + Status Rename(CtrlResponse *res, const string &source, + const string &dest) { + return SendRequestGetResponse(RenameRequest(user_name_, source, dest), res); + } + + private: + Status SendRequestGetResponse(const Request &request, Response *response); + + const string fs_name_; + const string user_name_; + ExtendedTCPClient client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea63436546d8b244b921206f9577c91b6578a775 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h" + +namespace tensorflow { + +ExtendedTCPClient::ExtendedTCPClient(const string &host, int port, + bool big_endian) + : PlainClient(host, port, big_endian), pos_(0) {} + +Status ExtendedTCPClient::ReadData(uint8_t *buf, const int32_t length) { + TF_RETURN_IF_ERROR(PlainClient::ReadData(buf, length)); + pos_ += length; + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteData(const uint8_t *buf, const int32_t length) { + TF_RETURN_IF_ERROR(PlainClient::WriteData(buf, length)); + pos_ += length; + + return Status::OK(); +} + +Status ExtendedTCPClient::Ignore(int n) { + uint8_t buf[n]; + return ReadData(buf, n); +} + +Status ExtendedTCPClient::SkipToPos(int target_pos) { + return Ignore(std::max(0, target_pos - pos_)); +} + +Status ExtendedTCPClient::ReadBool(bool *res) { + uint8_t buf = 0; + TF_RETURN_IF_ERROR(ReadData(&buf, 1)); + *res = buf != 0; + + return Status::OK(); +} + +Status ExtendedTCPClient::ReadNullableString(string *res) { + bool is_empty = false; + TF_RETURN_IF_ERROR(ReadBool(&is_empty)); + + if (!is_empty) { + TF_RETURN_IF_ERROR(ReadString(res)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::ReadString(string *res) { + int16_t length; + TF_RETURN_IF_ERROR(ReadShort(&length)); + + uint8_t *buf = new uint8_t[length]; + Status status = ReadData(buf, length); + + if (status.ok()) res->assign(reinterpret_cast(buf), length); + + delete[] buf; + return status; +} + +Status 
ExtendedTCPClient::ReadStringMap(std::map *res) { + int size; + TF_RETURN_IF_ERROR(ReadInt(&size)); + + for (int i = 0; i < size; i++) { + string key; + string val; + TF_RETURN_IF_ERROR(ReadString(&key)); + TF_RETURN_IF_ERROR(ReadString(&val)); + + res->insert(std::pair(std::move(key), std::move(val))); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteSize(std::map::size_type s) { + return WriteInt(s); +} + +Status ExtendedTCPClient::FillWithZerosUntil(int n) { + int to_skip = std::max(0, n - pos_); + + for (int i = 0; i < to_skip; i++) { + TF_RETURN_IF_ERROR(WriteByte(0)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteBool(bool val) { + return WriteByte((char)(val ? 1 : 0)); +} + +Status ExtendedTCPClient::WriteString(string str) { + if (!str.empty()) { + TF_RETURN_IF_ERROR(WriteBool(false)); + size_t l = str.length(); + if (l > std::numeric_limits::max()) + return errors::InvalidArgument("String is too long"); + + TF_RETURN_IF_ERROR(WriteShort(l)); + TF_RETURN_IF_ERROR(WriteData(reinterpret_cast(str.c_str()), + str.length())); + } else { + TF_RETURN_IF_ERROR(WriteBool(true)); + } + + return Status::OK(); +} + +Status ExtendedTCPClient::WriteStringMap(std::map map) { + std::map::size_type size = map.size(); + TF_RETURN_IF_ERROR(WriteSize(size)); + + for (auto &x : map) { + TF_RETURN_IF_ERROR(WriteString(x.first)); + TF_RETURN_IF_ERROR(WriteString(x.second)); + } + + return Status::OK(); +} + +void ExtendedTCPClient::reset() { pos_ = 0; } + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h new file mode 100644 index 0000000000000000000000000000000000000000..c5de342fd0c20cf5b01b647756797631b8a3f203 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ + +#include "tensorflow/contrib/ignite/kernels/client/ignite_plain_client.h" + +namespace tensorflow { + +class ExtendedTCPClient : public PlainClient { + public: + ExtendedTCPClient(const string &host, int port, bool big_endian); + Status ReadData(uint8_t *buf, const int32_t length) override; + Status WriteData(const uint8_t *buf, const int32_t length) override; + Status Ignore(int n); + Status SkipToPos(int target_pos); + Status ReadBool(bool *res); + Status ReadNullableString(string *res); + Status ReadString(string *res); + Status ReadStringMap(std::map *res); + Status WriteSize(std::map::size_type s); + Status FillWithZerosUntil(int n); + Status WriteBool(bool val); + Status WriteString(string str); + Status WriteStringMap(std::map map); + void reset(); + + private: + int pos_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_EXTENDED_TCP_CLIENT_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c63f40f35fa53bc51c44f574df50ad0c79fba91 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.cc @@ -0,0 
+1,344 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +Status IGFSPath::Read(ExtendedTCPClient *client) { + return client->ReadNullableString(&path); +} + +Status IGFSFile::Read(ExtendedTCPClient *client) { + int32_t block_size; + int64_t group_block_size; + std::map properties = {}; + int64_t access_time; + + bool has_path; + TF_RETURN_IF_ERROR(client->ReadBool(&has_path)); + if (has_path) { + IGFSPath path = {}; + TF_RETURN_IF_ERROR(path.Read(client)); + } + + TF_RETURN_IF_ERROR(client->ReadInt(&block_size)); + TF_RETURN_IF_ERROR(client->ReadLong(&group_block_size)); + TF_RETURN_IF_ERROR(client->ReadLong(&length)); + TF_RETURN_IF_ERROR(client->ReadStringMap(&properties)); + TF_RETURN_IF_ERROR(client->ReadLong(&access_time)); + TF_RETURN_IF_ERROR(client->ReadLong(&modification_time)); + TF_RETURN_IF_ERROR(client->ReadByte(&flags)); + + return Status::OK(); +} + +Request::Request(int32_t command_id) : command_id_(command_id) {} + +Status Request::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(client->WriteByte(0)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(8)); + TF_RETURN_IF_ERROR(client->WriteInt(command_id_)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(24)); + + return Status::OK(); +} + +Status 
Response::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->Ignore(1)); + TF_RETURN_IF_ERROR(client->SkipToPos(8)); + TF_RETURN_IF_ERROR(client->ReadInt(&req_id)); + TF_RETURN_IF_ERROR(client->SkipToPos(24)); + TF_RETURN_IF_ERROR(client->ReadInt(&res_type)); + + bool has_error; + TF_RETURN_IF_ERROR(client->ReadBool(&has_error)); + + if (has_error) { + int32_t error_code; + string error_msg; + TF_RETURN_IF_ERROR(client->ReadString(&error_msg)); + TF_RETURN_IF_ERROR(client->ReadInt(&error_code)); + + return errors::Unknown("Error [code=", error_code, ", message=\"", + error_msg, "\"]"); + } + + TF_RETURN_IF_ERROR(client->SkipToPos(header_size_ + 5)); + TF_RETURN_IF_ERROR(client->ReadInt(&length)); + TF_RETURN_IF_ERROR(client->SkipToPos(header_size_ + response_header_size_)); + + return Status::OK(); +} + +PathCtrlRequest::PathCtrlRequest(int32_t command_id_, const string &user_name, + const string &path, + const string &destination_path, bool flag, + bool collocate, + const std::map &properties) + : Request(command_id_), + user_name_(user_name), + path_(path), + destination_path_(destination_path), + flag_(flag), + collocate_(collocate), + props_(properties) {} + +Status PathCtrlRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(Request::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteString(user_name_)); + TF_RETURN_IF_ERROR(WritePath(client, path_)); + TF_RETURN_IF_ERROR(WritePath(client, destination_path_)); + TF_RETURN_IF_ERROR(client->WriteBool(flag_)); + TF_RETURN_IF_ERROR(client->WriteBool(collocate_)); + TF_RETURN_IF_ERROR(client->WriteStringMap(props_)); + + return Status::OK(); +} + +Status PathCtrlRequest::WritePath(ExtendedTCPClient *client, + const string &path) const { + TF_RETURN_IF_ERROR(client->WriteBool(!path.empty())); + if (!path.empty()) TF_RETURN_IF_ERROR(client->WriteString(path)); + + return Status::OK(); +} + +Status StreamCtrlRequest::Write(ExtendedTCPClient *client) const { + 
TF_RETURN_IF_ERROR(client->WriteByte(0)); + TF_RETURN_IF_ERROR(client->FillWithZerosUntil(8)); + TF_RETURN_IF_ERROR(client->WriteInt(command_id_)); + TF_RETURN_IF_ERROR(client->WriteLong(stream_id_)); + TF_RETURN_IF_ERROR(client->WriteInt(length_)); + + return Status::OK(); +} + +StreamCtrlRequest::StreamCtrlRequest(int32_t command_id_, int64_t stream_id, + int32_t length) + : Request(command_id_), stream_id_(stream_id), length_(length) {} + +DeleteRequest::DeleteRequest(const string &user_name, const string &path, + bool flag) + : PathCtrlRequest(DELETE_ID, user_name, path, {}, flag, true, {}) {} + +Status DeleteResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&exists)); + + return Status::OK(); +} + +ExistsRequest::ExistsRequest(const string &user_name, const string &path) + : PathCtrlRequest(EXISTS_ID, user_name, path, {}, false, true, {}) {} + +Status ExistsResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&exists)); + + return Status::OK(); +} + +HandshakeRequest::HandshakeRequest(const string &fs_name, const string &log_dir) + : Request(HANDSHAKE_ID), fs_name_(fs_name), log_dir_(log_dir) {} + +Status HandshakeRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(Request::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteString(fs_name_)); + TF_RETURN_IF_ERROR(client->WriteString(log_dir_)); + + return Status::OK(); +} + +Status HandshakeResponse::Read(ExtendedTCPClient *client) { + int64_t block_size; + bool sampling; + + TF_RETURN_IF_ERROR(client->ReadNullableString(&fs_name)); + TF_RETURN_IF_ERROR(client->ReadLong(&block_size)); + + bool has_sampling_; + TF_RETURN_IF_ERROR(client->ReadBool(&has_sampling_)); + + if (has_sampling_) { + TF_RETURN_IF_ERROR(client->ReadBool(&sampling)); + } + + return Status::OK(); +} + +ListRequest::ListRequest(int32_t command_id_, const string &user_name, + const string &path) + : PathCtrlRequest(command_id_, user_name, path, {}, false, true, 
{}) {} + +ListFilesRequest::ListFilesRequest(const string &user_name, const string &path) + : ListRequest(LIST_FILES_ID, user_name, path) {} + +ListPathsRequest::ListPathsRequest(const string &user_name, const string &path) + : ListRequest(LIST_PATHS_ID, user_name, path) {} + +OpenCreateRequest::OpenCreateRequest(const string &user_name, + const string &path) + : PathCtrlRequest(OPEN_CREATE_ID, user_name, path, {}, false, true, {}) {} + +Status OpenCreateRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteInt(replication_)); + TF_RETURN_IF_ERROR(client->WriteLong(blockSize_)); + + return Status::OK(); +} + +Status OpenCreateResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + + return Status::OK(); +} + +OpenAppendRequest::OpenAppendRequest(const string &user_name, + const string &path) + : PathCtrlRequest(OPEN_APPEND_ID, user_name, path, {}, false, true, {}) {} + +Status OpenAppendRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + return Status::OK(); +} + +Status OpenAppendResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + + return Status::OK(); +} + +OpenReadRequest::OpenReadRequest(const string &user_name, const string &path, + bool flag, + int32_t sequential_reads_before_prefetch) + : PathCtrlRequest(OPEN_READ_ID, user_name, path, {}, flag, true, {}), + sequential_reads_before_prefetch_(sequential_reads_before_prefetch) {} + +OpenReadRequest::OpenReadRequest(const string &user_name, const string &path) + : OpenReadRequest(user_name, path, false, 0) {} + +Status OpenReadRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(PathCtrlRequest::Write(client)); + + if (flag_) { + TF_RETURN_IF_ERROR(client->WriteInt(sequential_reads_before_prefetch_)); + } + + return Status::OK(); +} + +Status 
OpenReadResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadLong(&stream_id)); + TF_RETURN_IF_ERROR(client->ReadLong(&length)); + + return Status::OK(); +} + +InfoRequest::InfoRequest(const string &user_name, const string &path) + : PathCtrlRequest(INFO_ID, user_name, path, {}, false, true, {}) {} + +Status InfoResponse::Read(ExtendedTCPClient *client) { + file_info = IGFSFile(); + TF_RETURN_IF_ERROR(file_info.Read(client)); + + return Status::OK(); +} + +MakeDirectoriesRequest::MakeDirectoriesRequest(const string &user_name, + const string &path) + : PathCtrlRequest(MKDIR_ID, user_name, path, {}, false, true, {}) {} + +Status MakeDirectoriesResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +CloseRequest::CloseRequest(int64_t streamId) + : StreamCtrlRequest(CLOSE_ID, streamId, 0) {} + +Status CloseResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +ReadBlockRequest::ReadBlockRequest(int64_t stream_id, int64_t pos, + int32_t length) + : StreamCtrlRequest(READ_BLOCK_ID, stream_id, length), pos(pos) {} + +Status ReadBlockRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(StreamCtrlRequest::Write(client)); + + TF_RETURN_IF_ERROR(client->WriteLong(pos)); + + return Status::OK(); +} + +Status ReadBlockResponse::Read(ExtendedTCPClient *client, int32_t length, + uint8_t *dst) { + TF_RETURN_IF_ERROR(client->ReadData(dst, length)); + successfully_read = length; + + return Status::OK(); +} + +Status ReadBlockResponse::Read(ExtendedTCPClient *client) { + return Status::OK(); +} + +std::streamsize ReadBlockResponse::GetSuccessfullyRead() { + return successfully_read; +} + +ReadBlockCtrlResponse::ReadBlockCtrlResponse(uint8_t *dst) + : CtrlResponse(false), dst(dst) {} + +Status ReadBlockCtrlResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(Response::Read(client)); + + 
res = ReadBlockResponse(); + TF_RETURN_IF_ERROR(res.Read(client, length, dst)); + + return Status::OK(); +} + +WriteBlockRequest::WriteBlockRequest(int64_t stream_id, const uint8_t *data, + int32_t length) + : StreamCtrlRequest(WRITE_BLOCK_ID, stream_id, length), data(data) {} + +Status WriteBlockRequest::Write(ExtendedTCPClient *client) const { + TF_RETURN_IF_ERROR(StreamCtrlRequest::Write(client)); + TF_RETURN_IF_ERROR(client->WriteData((uint8_t *)data, length_)); + + return Status::OK(); +} + +RenameRequest::RenameRequest(const string &user_name, const string &path, + const string &destination_path) + : PathCtrlRequest(RENAME_ID, user_name, path, destination_path, false, true, + {}) {} + +Status RenameResponse::Read(ExtendedTCPClient *client) { + TF_RETURN_IF_ERROR(client->ReadBool(&successful)); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h new file mode 100644 index 0000000000000000000000000000000000000000..44a2928a2b2b48849c7ba4454e0e7848c2217b3b --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h @@ -0,0 +1,356 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_
+
+#include "tensorflow/contrib/ignite/kernels/igfs/igfs_extended_tcp_client.h"
+
+namespace tensorflow {
+
+// Command codes handed to the Request constructors in igfs_messages.cc.
+enum CommandId {
+  HANDSHAKE_ID = 0,
+  EXISTS_ID = 2,
+  INFO_ID = 3,
+  RENAME_ID = 6,
+  DELETE_ID = 7,
+  MKDIR_ID = 8,
+  LIST_PATHS_ID = 9,
+  LIST_FILES_ID = 10,
+  OPEN_READ_ID = 13,
+  OPEN_APPEND_ID = 14,
+  OPEN_CREATE_ID = 15,
+  CLOSE_ID = 16,
+  READ_BLOCK_ID = 17,
+  WRITE_BLOCK_ID = 18,
+};
+
+// Path entry that is populated from the client connection by Read().
+class IGFSPath {
+ public:
+  Status Read(ExtendedTCPClient *client);
+
+  string path;
+};
+
+// File entry that is populated from the client connection by Read().
+class IGFSFile {
+ public:
+  Status Read(ExtendedTCPClient *client);
+
+  int64_t length;
+  int64_t modification_time;
+  uint8_t flags;
+};
+
+// Base class of all outgoing messages. Subclasses call the base Write() and
+// then append their own payload.
+class Request {
+ public:
+  Request(int32_t command_id);
+  // Polymorphic base: make destruction through a Request pointer well defined.
+  virtual ~Request() = default;
+  virtual Status Write(ExtendedTCPClient *client) const;
+
+ protected:
+  const int32_t command_id_;
+};
+
+// Base class of all incoming messages; CtrlResponse extends Read().
+class Response {
+ public:
+  // Polymorphic base: make destruction through a Response pointer well
+  // defined.
+  virtual ~Response() = default;
+  virtual Status Read(ExtendedTCPClient *client);
+
+  int32_t res_type;
+  int32_t req_id;
+  int32_t length;
+
+ protected:
+  static const int32_t header_size_ = 24;
+  static const int32_t response_header_size_ = 9;
+};
+
+// Request that addresses a path (and optionally a destination path) on behalf
+// of a user.
+class PathCtrlRequest : public Request {
+ public:
+  // NOTE(review): the key/value types of `properties` were not visible in this
+  // chunk; string -> string is assumed. Confirm against igfs_messages.cc.
+  PathCtrlRequest(int32_t command_id, const string &user_name,
+                  const string &path, const string &destination_path, bool flag,
+                  bool collocate, const std::map<string, string> &properties);
+  Status Write(ExtendedTCPClient *client) const override;
+
+ protected:
+  Status WritePath(ExtendedTCPClient *client, const string &path) const;
+
+  const string user_name_;
+  const string path_;
+  const string destination_path_;
+  const bool flag_;
+  const bool collocate_;
+  const std::map<string, string> props_;
+};
+
+// Request that addresses an already opened stream by its id.
+class StreamCtrlRequest : public Request {
+ public:
+  StreamCtrlRequest(int32_t command_id, int64_t stream_id, int32_t length);
+  Status Write(ExtendedTCPClient *client) const override;
+
+ protected:
+  int64_t stream_id_;
+  int32_t length_;
+};
+
+// Response envelope around a payload of type R. Read() consumes the common
+// header through Response::Read(), then - only when `optional` was requested -
+// a presence flag, and finally the payload itself via R::Read().
+template <typename R>
+class CtrlResponse : public Response {
+ public:
+  CtrlResponse(bool optional) : optional_(optional) {}
+  Status Read(ExtendedTCPClient *client) override {
+    TF_RETURN_IF_ERROR(Response::Read(client));
+
+    if (optional_) {
+      TF_RETURN_IF_ERROR(client->ReadBool(&has_content));
+
+      if (!has_content) return Status::OK();
+    }
+
+    res = R();
+    has_content = true;
+    TF_RETURN_IF_ERROR(res.Read(client));
+
+    return Status::OK();
+  }
+
+  // Both fields are value-initialized so that they hold defined values even
+  // when Read() fails before reaching them.
+  R res{};
+  bool has_content = false;
+
+ private:
+  bool optional_;
+};
+
+// Payload made of an int32 element count followed by that many T entries.
+template <typename T>
+class ListResponse {
+ public:
+  Status Read(ExtendedTCPClient *client) {
+    int32_t len;
+    TF_RETURN_IF_ERROR(client->ReadInt(&len));
+
+    entries.clear();
+
+    for (int32_t i = 0; i < len; i++) {
+      T f = {};
+      TF_RETURN_IF_ERROR(f.Read(client));
+      entries.push_back(f);
+    }
+
+    return Status::OK();
+  }
+
+  std::vector<T> entries;
+};
+
+class DeleteRequest : public PathCtrlRequest {
+ public:
+  DeleteRequest(const string &user_name, const string &path, bool flag);
+};
+
+class DeleteResponse {
+ public:
+  Status Read(ExtendedTCPClient *client);
+
+  bool exists;
+};
+
+class ExistsRequest : public PathCtrlRequest {
+ public:
+  explicit ExistsRequest(const string &user_name, const string &path);
+};
+
+class ExistsResponse {
+ public:
+  Status Read(ExtendedTCPClient *client);
+
+  bool exists;
+};
+
+// First message of a session; carries the file system name and log directory.
+class HandshakeRequest : public Request {
+ public:
+  HandshakeRequest(const string &fs_name, const string &log_dir);
+  Status Write(ExtendedTCPClient *client) const override;
+
+ private:
+  string fs_name_;
+  string log_dir_;
+};
+
+class HandshakeResponse {
+ public:
+  Status Read(ExtendedTCPClient *client);
+
+  string fs_name;
+};
+
+// Common base of the two listing requests; only the command id differs.
+class ListRequest : public PathCtrlRequest {
+ public:
+  explicit ListRequest(int32_t command_id, const string &user_name,
+                       const string &path);
+};
+
+class ListFilesRequest : public ListRequest {
+ public:
+  ListFilesRequest(const string &user_name, const string &path);
+};
+ +class ListFilesResponse : public ListResponse {}; + +class ListPathsRequest : public ListRequest { + public: + ListPathsRequest(const string &user_name, const string &path); +}; + +class ListPathsResponse : public ListResponse {}; + +class OpenCreateRequest : public PathCtrlRequest { + public: + OpenCreateRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; + + private: + int32_t replication_; + int64_t blockSize_; +}; + +class OpenCreateResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; +}; + +class OpenAppendRequest : public PathCtrlRequest { + public: + explicit OpenAppendRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; +}; + +class OpenAppendResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; +}; + +class OpenReadRequest : public PathCtrlRequest { + public: + OpenReadRequest(const string &user_name, const string &path, bool flag, + int32_t seqReadsBeforePrefetch); + OpenReadRequest(const string &user_name, const string &path); + Status Write(ExtendedTCPClient *client) const override; + + protected: + /** Sequential reads before prefetch. */ + int32_t sequential_reads_before_prefetch_; +}; + +class OpenReadResponse { + public: + Status Read(ExtendedTCPClient *client); + + int64_t stream_id; + int64_t length; +}; + +class InfoRequest : public PathCtrlRequest { + public: + InfoRequest(const string &user_name, const string &path); +}; + +class InfoResponse { + public: + Status Read(ExtendedTCPClient *client); + + IGFSFile file_info; +}; + +class MakeDirectoriesRequest : public PathCtrlRequest { + public: + MakeDirectoriesRequest(const string &userName, const string &path); +}; + +class MakeDirectoriesResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +/** Stream control requests. 
**/ + +class CloseRequest : public StreamCtrlRequest { + public: + explicit CloseRequest(int64_t stream_id); +}; + +class CloseResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +class ReadBlockRequest : public StreamCtrlRequest { + public: + ReadBlockRequest(int64_t stream_id, int64_t pos, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + private: + int64_t pos; +}; + +class ReadBlockResponse { + public: + Status Read(ExtendedTCPClient *client, int32_t length, uint8_t *dst); + Status Read(ExtendedTCPClient *client); + std::streamsize GetSuccessfullyRead(); + + private: + int32_t length; + std::streamsize successfully_read; +}; + +class ReadBlockCtrlResponse : public CtrlResponse { + public: + ReadBlockCtrlResponse(uint8_t *dst); + Status Read(ExtendedTCPClient *client) override; + + private: + uint8_t *dst; +}; + +class WriteBlockRequest : public StreamCtrlRequest { + public: + WriteBlockRequest(int64_t stream_id, const uint8_t *data, int32_t length); + Status Write(ExtendedTCPClient *client) const override; + + private: + const uint8_t *data; +}; + +class RenameRequest : public PathCtrlRequest { + public: + RenameRequest(const string &user_name, const string &path, + const string &destination_path); +}; + +class RenameResponse { + public: + Status Read(ExtendedTCPClient *client); + + bool successful; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_MESSAGES_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4c898f14e6d298e65f563f4493a822172c40851 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +IGFSRandomAccessFile::IGFSRandomAccessFile(const string &file_name, + int64_t resource_id, + std::unique_ptr &&client) + : file_name_(file_name), + resource_id_(resource_id), + client_(std::move(client)) {} + +IGFSRandomAccessFile::~IGFSRandomAccessFile() { + CtrlResponse close_response = {false}; + Status status = client_->Close(&close_response, resource_id_); + + if (!status.ok()) LOG(ERROR) << status.ToString(); +} + +Status IGFSRandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, + char *scratch) const { + ReadBlockCtrlResponse response = ReadBlockCtrlResponse((uint8_t *)scratch); + TF_RETURN_IF_ERROR(client_->ReadBlock(&response, resource_id_, offset, n)); + + std::streamsize sz = response.res.GetSuccessfullyRead(); + if (sz == 0) return errors::OutOfRange("End of file"); + + *result = StringPiece(scratch, sz); + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h new file mode 100644 index 0000000000000000000000000000000000000000..b21369ff8a3b19774bcc743f93a5ec4ae1c9b49a --- /dev/null +++ 
b/tensorflow/contrib/ignite/kernels/igfs/igfs_random_access_file.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFSRandomAccessFile : public RandomAccessFile { + public: + IGFSRandomAccessFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client); + ~IGFSRandomAccessFile() override; + Status Read(uint64 offset, size_t n, StringPiece *result, + char *scratch) const override; + + private: + const string file_name_; + const int64_t resource_id_; + std::unique_ptr client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_RANDOM_ACCESS_FILE_H_ diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc new file mode 100644 index 0000000000000000000000000000000000000000..c15ecb7deeb0cf5a8a040e0d1e4b70c732729474 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h" +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_messages.h" + +namespace tensorflow { + +IGFSWritableFile::IGFSWritableFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client) + : file_name_(file_name), + resource_id_(resource_id), + client_(std::move(client)) {} + +IGFSWritableFile::~IGFSWritableFile() { + if (resource_id_ >= 0) { + CtrlResponse close_response = {false}; + + Status status = client_->Close(&close_response, resource_id_); + if (!status.ok()) LOG(ERROR) << status.ToString(); + } +} + +Status IGFSWritableFile::Append(StringPiece data) { + return client_->WriteBlock(resource_id_, (uint8_t *)data.data(), data.size()); +} + +Status IGFSWritableFile::Close() { + int64_t resource_to_be_closed = resource_id_; + resource_id_ = -1; + + CtrlResponse close_response = {false}; + return client_->Close(&close_response, resource_to_be_closed); +} + +Status IGFSWritableFile::Flush() { return Sync(); } + +Status IGFSWritableFile::Sync() { + CtrlResponse close_response = {false}; + TF_RETURN_IF_ERROR(client_->Close(&close_response, resource_id_)); + + CtrlResponse open_append_resp(false); + TF_RETURN_IF_ERROR(client_->OpenAppend(&open_append_resp, file_name_)); + + resource_id_ = open_append_resp.res.stream_id; + + return Status::OK(); +} + +} // namespace 
tensorflow diff --git a/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h new file mode 100644 index 0000000000000000000000000000000000000000..b406db17e0e350e2cef610bb05c40f658e100140 --- /dev/null +++ b/tensorflow/contrib/ignite/kernels/igfs/igfs_writable_file.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ +#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ + +#include "tensorflow/contrib/ignite/kernels/igfs/igfs_client.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +class IGFSWritableFile : public WritableFile { + public: + IGFSWritableFile(const string &file_name, int64_t resource_id, + std::unique_ptr &&client); + ~IGFSWritableFile() override; + Status Append(StringPiece data) override; + Status Close() override; + Status Flush() override; + Status Sync() override; + + private: + const string file_name_; + int64_t resource_id_; + std::unique_ptr client_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGFS_IGFS_WRITABLE_FILE_H_ diff --git a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc b/tensorflow/contrib/ignite/ops/igfs_ops.cc similarity index 62% rename from 
tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc rename to tensorflow/contrib/ignite/ops/igfs_ops.cc index f3b24b2341e590adfbeac1a18b6a65fbfd34f598..473bddff08b339d3b76a33d40fe34486acdbe151 100644 --- a/tensorflow/core/kernels/fuzzing/decode_jpeg_fuzz.cc +++ b/tensorflow/contrib/ignite/ops/igfs_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2016 Google Inc. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,17 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" +#include "tensorflow/core/platform/env.h" -namespace tensorflow { -namespace fuzzing { +#include "tensorflow/contrib/ignite/kernels/igfs/igfs.h" -class FuzzDecodeJpeg : public FuzzStringInputOp { - SINGLE_INPUT_OP_BUILDER(DT_STRING, DecodeJpeg); -}; +namespace tensorflow { -STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeJpeg); +REGISTER_FILE_SYSTEM("igfs", IGFS); -} // end namespace fuzzing -} // end namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py b/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1d6707d6400a7cd84016150d20973809aca20e --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/igfs_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading IGFS ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_ignite_ops.so")) diff --git a/tensorflow/contrib/ignite/python/ops/igfs_ops.py b/tensorflow/contrib/ignite/python/ops/igfs_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..12b973b707730f6ba5b057b74a46b27d8f973ede --- /dev/null +++ b/tensorflow/contrib/ignite/python/ops/igfs_ops.py @@ -0,0 +1,40 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ignite File System for checkpointing and communication with TensorBoard. 
+ +Apache Ignite is a memory-centric distributed database, caching, and +processing platform for transactional, analytical, and streaming workloads, +delivering in-memory speeds at petabyte scale. In addition to database +functionality Apache Ignite provides a distributed file system called +IGFS (https://ignite.apache.org/features/igfs.html). IGFS delivers a similar +functionality to Hadoop HDFS, but only in-memory. In fact, in addition to +its own APIs, IGFS implements Hadoop FileSystem API and can be transparently +plugged into Hadoop or Spark deployments. This contrib package contains an +integration between IGFS and TensorFlow. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + +file_system_library = os.path.join(resource_loader.get_data_files_path(), + "../../_ignite_ops.so") +load_library.load_file_system_library(file_system_library) diff --git a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py index c9af7386cf0a26ed1a950130aa36caa7fb831fd0..e450e2d84ba31a7de925fdb78fc972a592c6ad8c 100644 --- a/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py +++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py @@ -21,4 +21,4 @@ from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader _dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) + resource_loader.get_path_to_datafile("../../_ignite_ops.so")) diff --git a/tensorflow/contrib/signal/python/__init__.py b/tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh old mode 100644 new mode 100755 similarity index 73% rename from 
tensorflow/contrib/signal/python/__init__.py rename to tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh index e672d1146c53a813613c9076c0cb6056f7081441..5e39e16c05290f6b5786421670c69a3bd1e27add --- a/tensorflow/contrib/signal/python/__init__.py +++ b/tensorflow/contrib/ignite/python/tests/bin/start-igfs.sh @@ -1,4 +1,5 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Signal ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-igfs.xml & +sleep 5 # Wait Apache Ignite to be started + +tail -f nohup.out diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml new file mode 100644 index 0000000000000000000000000000000000000000..5d81bf33226cad0d5cc0ea1fb5c5b55672494976 --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-igfs.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 127.0.0.1 + + + + + + + + + diff --git a/tensorflow/contrib/ignite/python/tests/igfs_test.py b/tensorflow/contrib/ignite/python/tests/igfs_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cacfc568942e20200b7daf10599dde513a4a0a68 --- /dev/null +++ b/tensorflow/contrib/ignite/python/tests/igfs_test.py @@ -0,0 +1,215 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for IGFS.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.contrib.ignite.python.ops.igfs_ops # pylint: disable=unused-import +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test + + +class IGFSTest(test.TestCase): + """The Apache Ignite servers have to setup before the test and tear down + + after the test manually. The docker engine has to be installed. + + To setup Apache Ignite servers: + $ bash start_ignite.sh + + To tear down Apache Ignite servers: + $ bash stop_ignite.sh + """ + + def test_create_file(self): + """Test create file. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_create_file/1" + self.assertFalse(gfile.Exists(file_name)) + # Create file. + with gfile.Open(file_name, mode="w") as w: + w.write("") + # Check that file was created. + self.assertTrue(gfile.Exists(file_name)) + + def test_write_read_file(self): + """Test write/read file. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_write_read_file/1" + rows = 10000 + self.assertFalse(gfile.Exists(file_name)) + # Write data. + with gfile.Open(file_name, mode="w") as w: + for i in range(rows): + w.write("This is row\n") + # Read data. 
+ with gfile.Open(file_name, mode="r") as r: + lines = r.readlines() + # Check that data is equal. + self.assertEqual(rows, len(lines)) + for i in range(rows): + self.assertEqual("This is row\n", lines[i]) + + def test_delete_recursively(self): + """Test delete recursively. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_delete_recursively/" + file_name = "igfs:///test_delete_recursively/1" + self.assertFalse(gfile.Exists(dir_name)) + self.assertFalse(gfile.Exists(file_name)) + gfile.MkDir(dir_name) + with gfile.Open(file_name, mode="w") as w: + w.write("") + self.assertTrue(gfile.Exists(dir_name)) + self.assertTrue(gfile.Exists(file_name)) + # Delete directory recursively. + gfile.DeleteRecursively(dir_name) + # Check that directory was deleted. + self.assertFalse(gfile.Exists(dir_name)) + self.assertFalse(gfile.Exists(file_name)) + + def test_copy(self): + """Test copy. + + """ + # Setup and check preconditions. + src_file_name = "igfs:///test_copy/1" + dst_file_name = "igfs:///test_copy/2" + self.assertFalse(gfile.Exists(src_file_name)) + self.assertFalse(gfile.Exists(dst_file_name)) + with gfile.Open(src_file_name, mode="w") as w: + w.write("42") + self.assertTrue(gfile.Exists(src_file_name)) + self.assertFalse(gfile.Exists(dst_file_name)) + # Copy file. + gfile.Copy(src_file_name, dst_file_name) + # Check that files are identical. + self.assertTrue(gfile.Exists(src_file_name)) + self.assertTrue(gfile.Exists(dst_file_name)) + with gfile.Open(dst_file_name, mode="r") as r: + data = r.read() + self.assertEqual("42", data) + + def test_is_directory(self): + """Test is directory. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_is_directory/1" + file_name = "igfs:///test_is_directory/2" + with gfile.Open(file_name, mode="w") as w: + w.write("") + gfile.MkDir(dir_name) + # Check that directory is a directory. + self.assertTrue(gfile.IsDirectory(dir_name)) + # Check that file is not a directory. 
+ self.assertFalse(gfile.IsDirectory(file_name)) + + def test_list_directory(self): + """Test list directory. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_list_directory/" + file_names = [ + "igfs:///test_list_directory/1", "igfs:///test_list_directory/2/3" + ] + ch_dir_names = [ + "igfs:///test_list_directory/4", + ] + for file_name in file_names: + with gfile.Open(file_name, mode="w") as w: + w.write("") + for ch_dir_name in ch_dir_names: + gfile.MkDir(ch_dir_name) + ls_expected_result = file_names + ch_dir_names + # Get list of files in directory. + ls_result = gfile.ListDirectory(dir_name) + # Check that list of files is correct. + self.assertEqual(len(ls_expected_result), len(ls_result)) + for e in ["1", "2", "4"]: + self.assertTrue(e in ls_result) + + def test_make_dirs(self): + """Test make dirs. + + """ + # Setup and check preconditions. + dir_name = "igfs:///test_make_dirs/" + self.assertFalse(gfile.Exists(dir_name)) + # Make directory. + gfile.MkDir(dir_name) + # Check that directory was created. + self.assertTrue(gfile.Exists(dir_name)) + + def test_remove(self): + """Test remove. + + """ + # Setup and check preconditions. + file_name = "igfs:///test_remove/1" + self.assertFalse(gfile.Exists(file_name)) + with gfile.Open(file_name, mode="w") as w: + w.write("") + self.assertTrue(gfile.Exists(file_name)) + # Remove file. + gfile.Remove(file_name) + # Check that file was removed. + self.assertFalse(gfile.Exists(file_name)) + + def test_rename_file(self): + """Test rename file. + + """ + # Setup and check preconditions. + src_file_name = "igfs:///test_rename_file/1" + dst_file_name = "igfs:///test_rename_file/2" + with gfile.Open(src_file_name, mode="w") as w: + w.write("42") + self.assertTrue(gfile.Exists(src_file_name)) + # Rename file. + gfile.Rename(src_file_name, dst_file_name) + # Check that only new name of file is available. 
+ self.assertFalse(gfile.Exists(src_file_name)) + self.assertTrue(gfile.Exists(dst_file_name)) + with gfile.Open(dst_file_name, mode="r") as r: + data = r.read() + self.assertEqual("42", data) + + def test_rename_dir(self): + """Test rename dir. + + """ + # Setup and check preconditions. + src_dir_name = "igfs:///test_rename_dir/1" + dst_dir_name = "igfs:///test_rename_dir/2" + gfile.MkDir(src_dir_name) + # Rename directory. + gfile.Rename(src_dir_name, dst_dir_name) + # Check that only new name of directory is available. + self.assertFalse(gfile.Exists(src_dir_name)) + self.assertTrue(gfile.Exists(dst_dir_name)) + self.assertTrue(gfile.IsDirectory(dst_dir_name)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh index a67bd44f2fb0d654ba07f022a5070c68df8e2ede..112e0dea844620de600e277bff3685dd7c42c49c 100755 --- a/tensorflow/contrib/ignite/python/tests/start_ignite.sh +++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh @@ -20,3 +20,7 @@ SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )" # Start Apache Ignite with plain client listener. docker run -itd --name ignite-plain -p 42300:10800 \ -v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh + +# Start Apache Ignite with IGFS. 
+docker run -itd --name ignite-igfs -p 10500:10500 \ +-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-igfs.sh \ No newline at end of file diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh index 8f03dbd1ede61f548d3de9d9738f97667e75df3c..35b0f32d1b3e1373a231ff23f2b40c8ccc417baf 100755 --- a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh +++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh @@ -15,5 +15,4 @@ # ============================================================================== docker rm -f ignite-plain -docker rm -f ignite-ssl -docker rm -f ignite-ssl-auth +docker rm -f ignite-igfs \ No newline at end of file diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index c8812d4b23f94102d093db878a709b090a3318d6..588f15b867c1fedbadd5a5d945d870a356549468 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -70,7 +70,10 @@ py_test( "python/ops/core_test.py", ], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows + tags = [ + "no_windows", # TODO: needs investigation on Windows + "noasan", # TODO(b/119323169) + ], deps = [ ":_typecheck", ":core", diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index eab93f2cc5ed3d5179a58fa717d8b83d0c4d7337..e779eff68901af7042deb5c09b78a230e0d06d02 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -42,6 +42,7 @@ tensorflow/core/kernels/conv_grad_filter_ops.cc tensorflow/core/kernels/conv_grad_input_ops.cc tensorflow/core/kernels/conv_grad_ops.cc tensorflow/core/kernels/conv_ops.cc +tensorflow/core/kernels/conv_ops_3d.cc tensorflow/core/kernels/conv_ops_fused.cc tensorflow/core/kernels/conv_ops_using_gemm.cc tensorflow/core/kernels/crop_and_resize_op.cc @@ -163,6 +164,7 @@ 
tensorflow/core/kernels/pack_op.cc tensorflow/core/kernels/pad_op.cc tensorflow/core/kernels/padding_fifo_queue.cc tensorflow/core/kernels/padding_fifo_queue_op.cc +tensorflow/core/kernels/pooling_ops_3d.cc tensorflow/core/kernels/pooling_ops_common.cc tensorflow/core/kernels/population_count_op.cc tensorflow/core/kernels/quantization_utils.cc @@ -248,6 +250,7 @@ tensorflow/core/kernels/spectrogram_op.cc tensorflow/core/kernels/split_lib_cpu.cc tensorflow/core/kernels/split_op.cc tensorflow/core/kernels/split_v_op.cc +tensorflow/core/kernels/stack.cc tensorflow/core/kernels/stack_ops.cc tensorflow/core/kernels/strided_slice_op.cc tensorflow/core/kernels/strided_slice_op_inst_0.cc diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index b313024e2852caf2385454771b289ad0162cc463..45a60d79482787df4564ae3360f8252af93c7a26 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -51,7 +51,7 @@ The pruning library allows for specification of the following hyper parameters: | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | -| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | +| threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 256 | Number of bins to use for histogram computation. 
Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. | | block_height|integer | 1 | Number of rows in a block for block sparse matrices| diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index d2b811641764df05c66654dfcb044fa7e78853a5..f6b4373edd0544555dd16a373802d2feb5d674b1 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -204,7 +204,7 @@ def get_pruning_hparams(): begin_pruning_step=0, end_pruning_step=-1, weight_sparsity_map=[''], - threshold_decay=0.9, + threshold_decay=0.0, pruning_frequency=10, nbins=256, block_height=1, @@ -456,13 +456,14 @@ class Pruning(object): pool_window = [self._block_dim[0], self._block_dim[1]] pool_fn = pruning_utils.factorized_pool - + squeeze_axis = None if not self._spec.use_tpu: pool_fn = nn_ops.pool abs_weights = array_ops.reshape( abs_weights, [1, abs_weights.get_shape()[0], abs_weights.get_shape()[1], 1]) + squeeze_axis = [0, 3] pooled_weights = pool_fn( abs_weights, @@ -473,7 +474,7 @@ class Pruning(object): name=weights.op.name + '_pooled') if pooled_weights.get_shape().ndims != 2: - pooled_weights = array_ops.squeeze(pooled_weights) + pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis) smoothed_threshold, new_mask = self._update_mask(pooled_weights, threshold) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index 91b0bb7f6003c047e4dcf342695f433edbc11614..14fc51229ab53a77e8089040e8a8576babd0fafd 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -188,7 +188,6 @@ def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: values = ops.convert_to_tensor(values, 
name='values') values = array_ops.reshape(values, [-1]) - value_range = ops.convert_to_tensor(value_range, name='value_range') nbins_float = np.float32(nbins) # Map tensor values that fall within value_range to [0, 1]. @@ -250,7 +249,6 @@ def compute_cdf(values, value_range, **kwargs): name = kwargs.get('name', None) with ops.name_scope(name, 'cdf', [values, value_range, nbins]): values = ops.convert_to_tensor(values, name='values') - value_range = ops.convert_to_tensor(value_range, name='value_range') nbins_float = np.float32(nbins) # Map tensor values that fall within value_range to [0, 1]. @@ -336,7 +334,7 @@ def factorized_pool(input_tensor, padding=padding) return array_ops.squeeze( - array_ops.transpose(width_pooling, perm=[0, 1, 3, 2])) + array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]), axis=[0, 1]) def determine_partitioned_axis(partitioned_variable): diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index 0aca843497611552d922715514118cac003c29b2..d6f2bfcb6c2e2beda912eb538d8a4a0a17b486b3 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -85,8 +85,28 @@ class PruningUtilsTest(test.TestCase): @parameterized.named_parameters( - ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]), - ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1])) + ("Input_32x32_block_1x1", [32, 32], [1, 1]), + # block size 6x6 + ("Input_3x3_block_6x6", [3, 3], [6, 6]), + ("Input_32x32_block_6x6", [32, 32], [6, 6]), + ("Input_2x32_block_6x6", [2, 32], [6, 6]), + ("Input_32x2_block_6x6", [32, 2], [6, 6]), + ("Input_30x30_block_6x6", [30, 30], [6, 6]), + # block size 4x4 + ("Input_32x32_block_4x4", [32, 32], [4, 4]), + ("Input_2x32_block_4x4", [2, 32], [4, 4]), + ("Input_32x2_block_4x4", [32, 2], [4, 4]), + ("Input_30x30_block_4x4", [30, 30], [4, 4]), + # block size 1x4 + 
("Input_32x32_block_1x4", [32, 32], [1, 4]), + ("Input_2x32_block_1x4", [2, 32], [1, 4]), + ("Input_32x2_block_1x4", [32, 2], [1, 4]), + ("Input_30x30_block_1x4", [30, 30], [1, 4]), + # block size 4x1 + ("Input_32x32_block_4x1", [32, 32], [4, 1]), + ("Input_2x32_block_4x1", [2, 32], [4, 1]), + ("Input_32x2_block_4x1", [32, 2], [4, 1]), + ("Input_30x30_block_4x1", [30, 30], [4, 1])) class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): def _compare_pooling_methods(self, weights, pooling_kwargs): @@ -97,9 +117,11 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): array_ops.reshape( weights, [1, weights.get_shape()[0], - weights.get_shape()[1], 1]), **pooling_kwargs)) + weights.get_shape()[1], 1]), **pooling_kwargs), + axis=[0, 3]) pooled_weights_factorized_pool = pruning_utils.factorized_pool( weights, **pooling_kwargs) + self.assertAllClose(pooled_weights_tf.eval(), pooled_weights_factorized_pool.eval()) @@ -113,8 +135,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): [expanded_tensor, kronecker_product]) self.assertAllEqual(expanded_tensor_val, kronecker_product_val) - def testFactorizedAvgPool(self, window_shape): - weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + def testFactorizedAvgPool(self, input_shape, window_shape): + weights = variable_scope.get_variable("weights", shape=input_shape) pooling_kwargs = { "window_shape": window_shape, "pooling_type": "AVG", @@ -123,8 +145,8 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): } self._compare_pooling_methods(weights, pooling_kwargs) - def testFactorizedMaxPool(self, window_shape): - weights = variable_scope.get_variable("weights", shape=[1024, 2048]) + def testFactorizedMaxPool(self, input_shape, window_shape): + weights = variable_scope.get_variable("weights", shape=input_shape) pooling_kwargs = { "window_shape": window_shape, "pooling_type": "MAX", @@ -133,8 +155,8 @@ class 
PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): } self._compare_pooling_methods(weights, pooling_kwargs) - def testExpandTensor(self, block_dim): - weights = random_ops.random_normal(shape=[1024, 512]) + def testExpandTensor(self, input_shape, block_dim): + weights = random_ops.random_normal(shape=input_shape) self._compare_expand_tensor_with_kronecker_product(weights, block_dim) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index 9ce50bfe1054072b315adecb87f1ba729dfe0d83..b7fd2d2fb9db3eed15eb1cc2934199939790b1c0 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -106,6 +106,32 @@ class MovingAverageOptimizer(optimizer.Optimizer): self._swapped_variable_name_map[v_avg.op.name] = v.op.name return control_flow_ops.group(train_op, ma_op, name='train_with_avg') + def _find_swapped_variable(self, v_name_to_tensor, v_name, tensor): + """Returns name of swapped variable for given tensor. + + Args: + v_name_to_tensor: Mapping from variable names to tensors. + v_name: name of the variable for which swapped variable should be returned + tensor: Tensor which correspond to variable for which swapped variable + should be returned. + + Returns: + Tensor which correspond to swapped variable. + + Raises: + ValueError: If swapped variable could not be found in v_name_to_tensor. + """ + swapped_v_name = self._swapped_variable_name_map.get(v_name, None) + if swapped_v_name is None: + return tensor + else: + if swapped_v_name in v_name_to_tensor: + return v_name_to_tensor[swapped_v_name] + else: + raise ValueError( + ('Variable to swap %s is not part of variables to save. 
' + 'This breaks MovingAverageOptimizer.') % swapped_v_name) + def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): """Create a saver swapping moving averages and variables. @@ -141,33 +167,33 @@ class MovingAverageOptimizer(optimizer.Optimizer): if not isinstance(var_list, dict): var_list = saver.BaseSaverBuilder.OpListToDict(var_list) - # OpListToDict converts variables to tensors. We make sure we can get - # the unique variable name for normal and resource vaiables. - def get_v_name(tensor): - if tensor.op.type == 'ReadVariableOp': - return tensor.op.inputs[0].op.name - else: - return tensor.op.name - v_name_to_tensor = {} - for tensor in six.itervalues(var_list): - v_name = get_v_name(tensor) - v_name_to_tensor[v_name] = tensor + for k, tensor_or_list in six.iteritems(var_list): + # For each partitioned variable OpListToDict returns list of constituent + # parts instead of single tensor. + if (isinstance(tensor_or_list, list) + or isinstance(tensor_or_list, variables.PartitionedVariable)): + for tensor in tensor_or_list: + v_name = tensor.op.name + v_name_to_tensor[v_name] = tensor + else: + v_name_to_tensor[k] = tensor_or_list # Now swap variables and moving averages swapped_var_list = {} - for k, tensor in six.iteritems(var_list): - v_name = get_v_name(tensor) - swapped_v_name = self._swapped_variable_name_map.get(v_name, None) - tensor_to_save = tensor - if swapped_v_name is not None: - if swapped_v_name in v_name_to_tensor: - tensor_to_save = v_name_to_tensor[swapped_v_name] - else: - raise ValueError( - ('Variable to swap %s is not part of variables to save. 
' - 'This breaks MovingAverageOptimizer.') % swapped_v_name) - swapped_var_list[k] = tensor_to_save + for k, tensor_or_list in six.iteritems(var_list): + if isinstance(tensor_or_list, list): + tensor_list_to_save = [] + for tensor in tensor_or_list: + v_name = tensor.op.name + swapped_variable = self._find_swapped_variable(v_name_to_tensor, + v_name, + tensor) + tensor_list_to_save.append(swapped_variable) + swapped_var_list[k] = tensor_list_to_save + else: + swapped_var_list[k] = self._find_swapped_variable( + v_name_to_tensor, k, tensor_or_list) # Build the swapping saver. return saver.Saver(swapped_var_list, name=name, **kwargs) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index f22e7245285a8b2716645f9789eb5997928a22d2..643403eea6f88bcb33aa96d6539bc9a45a109c6b 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -43,97 +45,171 @@ class MovingAverageOptimizerTest(test.TestCase): # Test that MovingAverageOptimizer works with resource variables. self._helpTestRun(use_resource=True) - def _helpTestRun(self, use_resource=False): + def testRunUsePartitionedVars(self): + # Test that MovingAverageOptimizer works with partitioned variables. 
+ self._helpTestRun(use_partitioned_vars=True) + + def testRunUseResourcePartitionedVars(self): + # Test that MovingAverageOptimizer works with resource and partitioned + # variables. + self._helpTestRun(use_partitioned_vars=True, use_resource=True) + + def _helpTestRun(self, use_resource=False, use_partitioned_vars=False): + # Partitioned variables are represented as a "collection" of partitions. + # To simplify the test and reuse as much code as possible we employ + # following test strategy for partitioned variables. + # + # In the case of non-partitioned variables test runs on variables with + # shape [2]. + # + # In the case of partitioned variables we use shape [4] with two partitions, + # thus each partition has shape [2]. + # For partitioned variables the test is run twice (for loop over + # variable_part_names), first time on the first partition of each variable, + # second time on the second partition of each variable. + variable_part_names = ['part_0', 'part_1'] if use_partitioned_vars else [''] for sequential_update in [True, False]: for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.session(graph=ops.Graph()) as sess: - orig_val0 = [1.0, 2.0] - orig_val1 = [3.0, 4.0] - var0 = variable_scope.get_variable( - 'var0', - initializer=constant_op.constant(orig_val0, dtype=dtype), - use_resource=use_resource) - var1 = variable_scope.get_variable( - 'var1', - initializer=constant_op.constant(orig_val1, dtype=dtype), - use_resource=use_resource) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - - opt = moving_average_optimizer.MovingAverageOptimizer( - gradient_descent.GradientDescentOptimizer(learning_rate=2.0), - average_decay=0.5, - sequential_update=sequential_update) - save_dir = tempfile.mkdtemp( - prefix=os.path.join(self.get_temp_dir(), 'run_1')) - save_path = os.path.join(save_dir, 'model') - update = opt.apply_gradients( - list(six.moves.zip([grads0, grads1], 
[var0, var1]))) - global_vars = variables.global_variables() - ema_var0 = [ - v for v in global_vars - if v.op.name == 'var0/ExponentialMovingAverage' - ][0] - ema_var1 = [ - v for v in global_vars - if v.op.name == 'var1/ExponentialMovingAverage' - ][0] - perturb = control_flow_ops.group([ - state_ops.assign_add(var0, [1.0, 1.0]), - state_ops.assign_add(var1, [2.0, 2.0]), - state_ops.assign_add(ema_var0, [3.0, 3.0]), - state_ops.assign_add(ema_var1, [4.0, 4.0]) - ]) - - # Test that saver with missing ema variables will fail. - with self.assertRaisesRegexp(ValueError, r'Variable to swap'): - opt.swapping_saver(var_list=[var0]) - - train_saver = opt.swapping_saver() - train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) - inference_saver = saver.Saver() - variables.global_variables_initializer().run() - # Step 1. - update.run() - self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) - if sequential_update: - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) - # Test that the swapping saver save/restore operation is identity. - train_saver.save(sess, save_path) - train_saver.restore(sess, save_path) - self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) - if sequential_update: - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) - # Test that the subset saver saves the EMA variable as well. 
- if sequential_update: - subset_save_path = save_path + '_subset' - train_saver_subset.save(sess, subset_save_path) - perturb.run() - self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) - self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) - self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) - # Restoring should only restore var0 and ema_var0. - train_saver_subset.restore(sess, subset_save_path) + for var_part_name in variable_part_names: + with self.session(graph=ops.Graph()) as sess: + orig_val0 = [1.0, 2.0] + orig_val1 = [3.0, 4.0] + grads0 = [0.1, 0.1] + grads1 = [0.01, 0.01] + if use_partitioned_vars: + # Use partitioned variables. + # Create partitioned and duplicate each value used as initial + # value of variables. + partitioner = partitioned_variables.fixed_size_partitioner( + num_shards=2) + orig_val0 = orig_val0 * 2 + orig_val1 = orig_val1 * 2 + grads0 = grads0 * 2 + grads1 = grads1 * 2 + else: + # Regular (non-partitioned) variables. 
+ partitioner = None + var0 = variable_scope.get_variable( + 'var0', + initializer=constant_op.constant(orig_val0, dtype=dtype), + use_resource=use_resource, + partitioner=partitioner) + var1 = variable_scope.get_variable( + 'var1', + initializer=constant_op.constant(orig_val1, dtype=dtype), + use_resource=use_resource, + partitioner=partitioner) + # Make a fake loss, such that gradient(loss, var0) == grads0 + # and gradient(loss, var1) == grads1 + grads0 = constant_op.constant(grads0, dtype=dtype) + grads1 = constant_op.constant(grads1, dtype=dtype) + loss = (math_ops.reduce_sum(grads0 * var0) + + math_ops.reduce_sum(grads1 * var1)) + + opt = moving_average_optimizer.MovingAverageOptimizer( + gradient_descent.GradientDescentOptimizer(learning_rate=2.0), + average_decay=0.5, + sequential_update=sequential_update) + save_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'run_1')) + save_path = os.path.join(save_dir, 'model') + + update = opt.minimize(loss) + + # Get variables and their EMAs. In case of partitioned variables + # get proper part of each variable. + def _get_variable(var_name, part_name, ema): + """Returns variable of it's moving average by name.""" + matches = [ + v for v in variables.global_variables() + if ((var_name in v.op.name) + and (part_name in v.op.name) + and (('ExponentialMovingAverage' in v.op.name) == ema)) + ] + self.assertEqual(len(matches), 1) + return matches[0] + var0 = _get_variable('var0', var_part_name, ema=False) + var1 = _get_variable('var1', var_part_name, ema=False) + ema_var0 = _get_variable('var0', var_part_name, ema=True) + ema_var1 = _get_variable('var1', var_part_name, ema=True) + + perturb = control_flow_ops.group([ + state_ops.assign_add(var0, [1.0, 1.0]), + state_ops.assign_add(var1, [2.0, 2.0]), + state_ops.assign_add(ema_var0, [3.0, 3.0]), + state_ops.assign_add(ema_var1, [4.0, 4.0]) + ]) + + # Test that saver with missing ema variables will fail. 
+ with self.assertRaisesRegexp(ValueError, r'Variable to swap'): + opt.swapping_saver(var_list=[var0]) + + train_saver = opt.swapping_saver() + train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) + inference_saver = saver.Saver() + variables.global_variables_initializer().run() + # Step 1. + update.run() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) - self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) - self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) - self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) - # Restore back to previous state. + self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the swapping saver save/restore operation is identity. + train_saver.save(sess, save_path) train_saver.restore(sess, save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the subset saver saves the EMA variable as well. + if sequential_update: + subset_save_path = save_path + '_subset' + train_saver_subset.save(sess, subset_save_path) + perturb.run() + self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) + self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restoring should only restore var0 and ema_var0. 
+ train_saver_subset.restore(sess, subset_save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restore back to previous state. + train_saver.restore(sess, save_path) - # If updates are parallel, this is not always true after the 1st step. - if sequential_update: + # If updates are parallel, + # this is not always true after the 1st step. + if sequential_update: + # Test that the normal saver will have the averaged variables. + # We test that the average values are between the original value + # and the most recent variable values (since they are an average + # of the two). + val0 = var0.eval() + val1 = var1.eval() + train_saver.save(sess, save_path) + inference_saver.restore(sess, save_path) + avg_val0 = var0.eval() + avg_val1 = var1.eval() + for i in six.moves.range(len(val0)): + self.assertLess(val0[i], avg_val0[i]) + self.assertLess(avg_val0[i], orig_val0[i]) + self.assertLess(val1[i], avg_val1[i]) + self.assertLess(avg_val1[i], orig_val1[i]) + train_saver.restore(sess, save_path) + # Step 2. + update.run() # Test that the normal saver will have the averaged variables. - # We test that the average values are between the original value - # and the most recent variable values (since they are an average - # of the two). + # We test that the average values are between the original value and + # the most recent variable values (since they are an average of the + # two). 
val0 = var0.eval() val1 = var1.eval() + self.assertAllCloseAccordingToType([0.6, 1.6], val0) + self.assertAllCloseAccordingToType([2.96, 3.96], val1) train_saver.save(sess, save_path) inference_saver.restore(sess, save_path) avg_val0 = var0.eval() @@ -143,26 +219,6 @@ class MovingAverageOptimizerTest(test.TestCase): self.assertLess(avg_val0[i], orig_val0[i]) self.assertLess(val1[i], avg_val1[i]) self.assertLess(avg_val1[i], orig_val1[i]) - train_saver.restore(sess, save_path) - # Step 2. - update.run() - # Test that the normal saver will have the averaged variables. - # We test that the average values are between the original value and - # the most recent variable values (since they are an average of the - # two). - val0 = var0.eval() - val1 = var1.eval() - self.assertAllCloseAccordingToType([0.6, 1.6], val0) - self.assertAllCloseAccordingToType([2.96, 3.96], val1) - train_saver.save(sess, save_path) - inference_saver.restore(sess, save_path) - avg_val0 = var0.eval() - avg_val1 = var1.eval() - for i in six.moves.range(len(val0)): - self.assertLess(val0[i], avg_val0[i]) - self.assertLess(avg_val0[i], orig_val0[i]) - self.assertLess(val1[i], avg_val1[i]) - self.assertLess(avg_val1[i], orig_val1[i]) def testFailWhenSaverCreatedBeforeInitialized(self): with self.cached_session(): diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index 44a8890cb107440b79cf8fbbdfcfda503b1c910f..155ff5b3f4f29d4d9c81bb265d19d1b8cce4fef2 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f789c83e005ab7ad7e7caff4ef9ee3c2f57c21fe..467dd86d8fd247a42be2dc47d5bf9872e14da89e 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -790,14 +790,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Scale loss for number of replicas (callable-loss case). In this case, # we have to be careful to call distribute_lib.get_loss_reduction() # *after* loss() is evaluated, so we know what loss reduction it uses. - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) - if scale_loss_by_num_replicas: - num_replicas = distribute_ctx.get_distribution_strategy().num_replicas - if num_replicas > 1: - loss_value *= 1. / num_replicas + loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas) if var_list is None: var_list = tape.watched_variables() @@ -808,14 +801,7 @@ class OptimizerV2(optimizer_v1.Optimizer): "be a function when eager execution is enabled.") # Scale loss for number of replicas (non-callable-loss case). - if scale_loss_by_num_replicas is None: - scale_loss_by_num_replicas = ( - distribute_lib.get_loss_reduction() == variable_scope - .VariableAggregation.MEAN) - if scale_loss_by_num_replicas: - num_replicas = distribute_ctx.get_distribution_strategy().num_replicas - if num_replicas > 1: - loss *= 1. 
/ num_replicas + loss = self._scale_loss(loss, scale_loss_by_num_replicas) if gate_gradients not in [ optimizer_v1.Optimizer.GATE_NONE, optimizer_v1.Optimizer.GATE_OP, @@ -857,6 +843,20 @@ class OptimizerV2(optimizer_v1.Optimizer): ]) return grads_and_vars + @staticmethod + def _scale_loss(loss_value, scale_loss_by_num_replicas): + """Scale loss for the number of replicas.""" + if scale_loss_by_num_replicas is None: + scale_loss_by_num_replicas = ( + distribute_lib.get_loss_reduction() == variable_scope + .VariableAggregation.MEAN) + if scale_loss_by_num_replicas: + num_replicas = \ + distribute_ctx.get_distribution_strategy().num_replicas_in_sync + if num_replicas > 1: + loss_value *= 1. / num_replicas + return loss_value + def apply_gradients(self, grads_and_vars, global_step=None, name=None): """Apply gradients to variables. diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index 057e851aba68c485867c20e964fef750c3158a01..15ae95f13cffa5d1469d737b23f2a83b9e5a694f 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -141,7 +141,7 @@ __global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, // const int gid = batch_id * cell_size * 4 + act_id; const int cid = batch_id * cell_size + act_id; - Eigen::internal::scalar_sigmoid_op sigmoid_op; + Eigen::internal::scalar_logistic_op sigmoid_op; Eigen::internal::scalar_tanh_op tanh_op; Eigen::scalar_clip_op clip_op; diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 291ff83791c7cded2dccc4719bb12e84f00afa42..f0947fe423f7e6bf84dae468bc36ca11147ac0bb 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -82,7 +82,6 @@ py_library( name = "keras_saved_model", srcs = ["python/saved_model/keras_saved_model.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], visibility = ["//visibility:public"], 
deps = [ "//tensorflow/python:array_ops", @@ -103,7 +102,7 @@ py_test( size = "medium", srcs = ["python/saved_model/keras_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], + tags = ["no_windows"], deps = [ ":keras_saved_model", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py index 4970ebc31992c28b6c352e9feb9b41f853bebe60..a65b2ce466111c33d0092b7018537573708de2d0 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -345,21 +345,22 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): inputs, outputs = load_model(sess, output_path, model_fn_lib.ModeKeys.EVAL) - sess.run(outputs['metrics/mae/update_op'], { - inputs[input_name]: input_arr, - inputs[target_name]: target_arr - }) + # First obtain the loss and predictions, and run the metric update op by + # feeding in the inputs and targets. + loss, predictions, _ = sess.run( + (outputs['loss'], outputs['predictions/' + output_name], + outputs['metrics/mae/update_op']), + {inputs[input_name]: input_arr, inputs[target_name]: target_arr}) - eval_results = sess.run(outputs, {inputs[input_name]: input_arr, - inputs[target_name]: target_arr}) + # The metric value should be run after the update op, to ensure that it + # reflects the correct value. 
+ metric_value = sess.run(outputs['metrics/mae/value']) self.assertEqual(int(train_before_export), sess.run(training_module.get_global_step())) - self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05) - self.assertAllClose( - ref_mae, eval_results['metrics/mae/value'], atol=1e-05) - self.assertAllClose( - ref_predict, eval_results['predictions/' + output_name], atol=1e-05) + self.assertAllClose(ref_loss, loss, atol=1e-05) + self.assertAllClose(ref_mae, metric_value, atol=1e-05) + self.assertAllClose(ref_predict, predictions, atol=1e-05) # Load train graph, and check for the train op, and prediction values with session.Session(graph=ops.Graph()) as sess: diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 6bd58c4d322c04d4d14d04678e24a05c0f876208..5e4f130b31483204a111e2f778fa5d0fc4526fea 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -4,129 +4,11 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") -load("//tensorflow:tensorflow.bzl", "py_test") # @unused - py_library( name = "signal_py", - srcs = ["__init__.py"] + glob(["python/ops/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:spectral_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_library( - name = "test_util", - srcs = ["python/kernel_tests/test_util.py"], + srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:tf_optimizer", - "//tensorflow/python:training", - ], -) - -cuda_py_tests( - name = "mel_ops_test", - srcs = ["python/kernel_tests/mel_ops_test.py"], - additional_deps = [ - ":signal_py", - 
":test_util", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - ], -) - -cuda_py_tests( - name = "mfcc_ops_test", - srcs = ["python/kernel_tests/mfcc_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:spectral_ops_test_util", - ], -) - -cuda_py_tests( - name = "reconstruction_ops_test", - srcs = ["python/kernel_tests/reconstruction_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "shape_ops_test", - srcs = ["python/kernel_tests/shape_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "spectral_ops_test", - size = "large", - srcs = ["python/kernel_tests/spectral_ops_test.py"], - additional_deps = [ - ":signal_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - 
"//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops_test_util", - ], - tags = ["nomac"], -) - -cuda_py_tests( - name = "window_ops_test", - srcs = ["python/kernel_tests/window_ops_test.py"], - additional_deps = [ - ":signal_py", - ":test_util", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", + "//tensorflow/python/ops/signal", ], ) diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index d088e744346aac0aa8675b95d7b792379fc7b019..d01f5ccf51c132082a419ec7db49045ef8bab725 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -14,6 +14,9 @@ # ============================================================================== """Signal processing operations. +`tf.contrib.signal` has been renamed to `tf.signal`. `tf.contrib.signal` will be +removed in TensorFlow 2.0. + See the [Contrib Signal](https://tensorflow.org/api_guides/python/contrib.signal) guide. 
@@ -39,18 +42,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix -from tensorflow.contrib.signal.python.ops.mfcc_ops import mfccs_from_log_mel_spectrograms -from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add -from tensorflow.contrib.signal.python.ops.shape_ops import frame +from tensorflow.python.ops.signal.mel_ops import linear_to_mel_weight_matrix +from tensorflow.python.ops.signal.mfcc_ops import mfccs_from_log_mel_spectrograms +from tensorflow.python.ops.signal.reconstruction_ops import overlap_and_add +from tensorflow.python.ops.signal.shape_ops import frame +from tensorflow.python.ops.signal.spectral_ops import inverse_stft +from tensorflow.python.ops.signal.spectral_ops import inverse_stft_window_fn +from tensorflow.python.ops.signal.spectral_ops import stft +from tensorflow.python.ops.signal.window_ops import hamming_window +from tensorflow.python.ops.signal.window_ops import hann_window + +from tensorflow.python.util.all_util import remove_undocumented + # `frame` used to be named `frames`, which is a noun and not a verb. # Keep an alias to `frames` for backwards compatibility. 
-from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames -from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft -from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft_window_fn -from tensorflow.contrib.signal.python.ops.spectral_ops import stft -from tensorflow.contrib.signal.python.ops.window_ops import hamming_window -from tensorflow.contrib.signal.python.ops.window_ops import hann_window +frames = frame -from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py index 596c59ead3460aa63eeff44d5a11a4a8c5cde0da..290c16fe3966791ea78986539750caf938a37322 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops # pylint: disable=unused-import @@ -28,10 +30,12 @@ from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_mod # pylint: enable=unused-import from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking _model_ops = loader.load_op_library( @@ -88,6 +92,59 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject): params=self.params.serialized_params_proto) +class TreeVariable(tracking.TrackableResource): + """A tree model.""" + + def __init__(self, params, tree_config, stats_handle, name, container=None): + self._params = params + self._tree_config 
= tree_config + self._stats_handle = stats_handle + self._name = name + self._container = container + self._init_op = None + super(TreeVariable, self).__init__() + self._resource_handle = self.create_resource() + + def create_resource(self): + if context.executing_eagerly(): + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). + shared_name = "tree_variable_%d" % (ops.uid(),) + else: + shared_name = self._name + return gen_model_ops.decision_tree_resource_handle_op( + self._container, shared_name=shared_name, name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_variable( + self.resource_handle, + self._tree_config, + params=self._params.serialized_params_proto) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_model_ops.tree_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return { + "tree_variable": + functools.partial( + TreeVariableSavable, + params=self._params, + tree_handle=self.resource_handle, + stats_handle=self._stats_handle, + create_op=self._init_op) + } + + def tree_variable(params, tree_config, stats_handle, name, container=None): r"""Creates a tree model and returns a handle to it. @@ -102,18 +159,13 @@ def tree_variable(params, tree_config, stats_handle, name, container=None): A `Tensor` of type mutable `string`. The handle to the tree. 
""" with ops.name_scope(name, "TreeVariable") as name: - resource_handle = gen_model_ops.decision_tree_resource_handle_op( - container, shared_name=name, name=name) - - create_op = gen_model_ops.create_tree_variable( - resource_handle, - tree_config, - params=params.serialized_params_proto) - is_initialized_op = gen_model_ops.tree_is_initialized_op(resource_handle) + tree_var = TreeVariable(params, tree_config, stats_handle, name, container) + resource_handle = tree_var.resource_handle + create_op = tree_var.initializer + is_initialized_op = tree_var.is_initialized() # Adds the variable to the savable list. - saveable = TreeVariableSavable(params, resource_handle, stats_handle, - create_op, - resource_handle.name) + saveable = tree_var._gather_saveables_for_checkpoint()["tree_variable"]( # pylint: disable=protected-access + name=resource_handle.name) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) resources.register_resource(resource_handle, create_op, is_initialized_op) return resource_handle diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py index 44d486edecc4e4f7ba8a9b6d680178298813621b..9184198cd4c8fd2a7609714d094d5ef2b6868658 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.tensor_forest.python.ops import gen_stats_ops # pylint: disable=unused-import from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import finalize_tree @@ -25,10 +27,12 @@ from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import process_in # pylint: enable=unused-import from tensorflow.contrib.util import loader +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops 
import resources from tensorflow.python.platform import resource_loader from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking _stats_ops = loader.load_op_library( @@ -84,8 +88,58 @@ class FertileStatsVariableSavable(saver.BaseSaverBuilder.SaveableObject): params=self.params.serialized_params_proto) -def fertile_stats_variable(params, stats_config, name, - container=None): +class FertileStatsVariable(tracking.TrackableResource): + """A Fertile stats variable.""" + + def __init__(self, params, stats_config, name, container=None): + self._params = params + self._stats_config = stats_config + self._name = name + self._container = container + self._init_op = None + super(FertileStatsVariable, self).__init__() + self._resource_handle = self.create_resource() + + def create_resource(self): + if context.executing_eagerly(): + # TODO(allenl): This will leak memory due to kernel caching by the + # shared_name attribute value (but is better than the alternative of + # sharing everything by default when executing eagerly; hopefully creating + # tables in a loop is uncommon). 
+ shared_name = "fertile_stats_variable_%d" % (ops.uid(),) + else: + shared_name = self._name + return gen_stats_ops.fertile_stats_resource_handle_op( + self._container, shared_name=shared_name, name=self._name) + + def initialize(self): + return gen_stats_ops.create_fertile_stats_variable( + self.resource_handle, + self._stats_config, + params=self._params.serialized_params_proto) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_stats_ops.fertile_stats_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + """For object-based checkpointing.""" + return { + "fertile_stats_variable": + functools.partial( + FertileStatsVariableSavable, + params=self._params, + stats_handle=self.resource_handle, + create_op=self.initializer) + } + + +def fertile_stats_variable(params, stats_config, name, container=None): r"""Creates a stats object and returns a handle to it. Args: @@ -98,17 +152,15 @@ def fertile_stats_variable(params, stats_config, name, A `Tensor` of type mutable `string`. The handle to the stats. """ with ops.name_scope(name, "FertileStatsVariable") as name: - resource_handle = gen_stats_ops.fertile_stats_resource_handle_op( - container, shared_name=name, name=name) - - create_op = gen_stats_ops.create_fertile_stats_variable( - resource_handle, stats_config, - params=params.serialized_params_proto) - is_initialized_op = gen_stats_ops.fertile_stats_is_initialized_op( - resource_handle) + fertile_stats_var = FertileStatsVariable(params, stats_config, name, + container) + resource_handle = fertile_stats_var.resource_handle + create_op = fertile_stats_var.initializer + is_initialized_op = fertile_stats_var.is_initialized() # Adds the variable to the savable list. 
- saveable = FertileStatsVariableSavable(params, resource_handle, create_op, - resource_handle.name) + saveable = ( + fertile_stats_var._gather_saveables_for_checkpoint()[ # pylint: disable=protected-access + "fertile_stats_variable"](name=resource_handle.name)) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) resources.register_resource(resource_handle, create_op, is_initialized_op) return resource_handle diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 1f5591fe2a602517dd30d6ab1772dc2ff7c523ed..26d54eb156ccc8593d82609195caabb5bb929262 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -141,6 +141,7 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { std::vector<const Edge*> input_edges; TF_RETURN_IF_ERROR(node->input_edges(&input_edges)); std::vector<std::pair<const NodeDef*, int>> input_node_and_ports; + input_node_and_ports.reserve(input_edges.size()); for (const Edge* input_edge : input_edges) { input_node_and_ports.emplace_back(&input_edge->src()->def(), input_edge->src_output()); @@ -923,7 +924,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { converted_segments.push_back(std::move(curr_segment)); if (VLOG_IS_ON(8)) { - string fname = curr_engine.engine_name; + string fname = engine_segments.back().engine_name; StrAppend(&fname, ".pb"); std::fstream f; f.open(fname.c_str(), std::fstream::out | std::fstream::binary); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index a6f954391d3f2bd43b187fad0468c0c83d176803..e2988f5f2a8f6164cbe193573b267e6ffeef3284 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -1552,6 +1552,7 @@ tensorflow::Status ConvertPlugin(OpConverterParams* params) { const auto& node_def = params->node_def; // prepare input
std::vector<nvinfer1::ITensor*> all_inputs; + all_inputs.reserve(inputs.size()); for (auto input : inputs) { all_inputs.emplace_back(const_cast<nvinfer1::ITensor*>(input.tensor())); } diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 67327d32000caea5db75f4d83e5743e8bde70a92..a0a9cb3f31a945a00eb3f6a5fd1402aab9a2df5f 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -246,6 +246,7 @@ py_library( "python/tpu/bfloat16.py", "python/tpu/device_assignment.py", "python/tpu/session_support.py", + "python/tpu/tensor_tracer.py", "python/tpu/topology.py", "python/tpu/tpu.py", "python/tpu/tpu_feed.py", diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 38d1c3049ef7185f2f9f448361029d066678cdae..541fbf33a302a4d850422885fdbbc438bd6b9b7b 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -94,13 +94,6 @@ tf_proto_library( visibility = ["//visibility:public"], ) -tf_proto_library( - name = "tf_op_stats_proto", - srcs = ["tf_op_stats.proto"], - cc_api_version = 2, - visibility = ["//visibility:public"], -) - tf_proto_library( name = "tpu_profiler_analysis_proto", srcs = ["tpu_profiler_analysis.proto"], diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto deleted file mode 100644 index 1e66801efd4b2a997ed85289b9b1690bb5d07737..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ /dev/null @@ -1,261 +0,0 @@ -// This proto describes the format of tensorflow operation level stats for -// profiling (in tensorboard) purpose. - -syntax = "proto2"; - -package tensorflow.tpu; - -// Result proto for OpMetrics. -message OpMetricsResult { - // True if this OP is executed on the device; False if it is executed on the - // host. - optional bool on_device = 1; - reserved 2; // was uint32 id. - // Name of this OP. - optional string name = 3; - // Rank of this OP.
- optional uint64 rank = 4; - // The starting time in cycles of the last instance of this OP executed. - optional double last_starttime_in_cycles = 5; - // The ending time in cycles of the last instance of this OP executed. - optional double last_endtime_in_cycles = 6; - // If this OP (say A), is an immediate child of another OP (say B), this field - // stores the sum of duration in microseconds of A inside B. If A appears more - // than once in B, the duration of all A's appearances will be added together. - // This sum will be reset after the self-time of B is calculated so that it - // can be reused for a new parent OP. - optional double sum_of_duration_in_us_as_children = 7; - // Number of instances that this OP occurred. - optional uint64 occurrences = 8; - // Total time in microseconds spent in this OP (accumulated - // over all of its occurrences). - optional double total_time_in_us = 9; - // Total self time in microseconds spent in this OP - // (accumulated over all of its occurrences). - optional double total_self_time_in_us = 10; - // The total self time as a fraction of sum of all OP's - // total self time on the host. - optional double host_total_self_time_as_fraction_of_all_op_time = 11; - // Cumulative total self time in fraction on the host. - optional double host_cumulative_total_self_time_as_fraction_of_all_op_time = - 12; - // The total self time as a fraction of sum of all OP's - // total self time on the device. - optional double device_total_self_time_as_fraction_of_all_op_time = 13; - // Cumulative total self time in fraction on the device. - optional double device_cumulative_total_self_time_as_fraction_of_all_op_time = - 14; - // Total number of FLOPs incurred by this OP. - optional double total_flops = 15; - // Total number of bytes accessed by this OP. - optional double total_bytes_accessed = 16; - // Total time in microseconds that special hw unit 1 is occupied by this OP. 
- optional double unit1_occupancy_in_us = 17; - // Total time in microseconds that special hw unit 2 is occupied by this OP. - optional double unit2_occupancy_in_us = 18; - // Total memory stall time in microseconds. - optional double total_memory_stall_in_us = 19; -} - -// Result proto for OpMetricsDb. -message OpMetricsDbResult { - // A bunch of OpMetricsResults. - repeated OpMetricsResult metrics_db = 1; - // The total host infeed-enqueue duration in picoseconds. - optional uint64 total_host_infeed_enq_duration_ps = 2; - // The total of the difference between the start times of two - // consecutive infeed-enqueues (per host) in picoseconds. - optional uint64 total_host_infeed_enq_start_timestamp_ps_diff = 3; - // The total device time in microseconds. - optional double total_device_time_in_us = 4; - // The total host time in microseconds. - optional double total_host_time_in_us = 5; -} - -// Result proto for StepInfo. -message StepInfoResult { - // The (micro) step number. - optional uint32 step_num = 1; - // The step duration in picoseconds. - optional uint64 duration_ps = 2; - // The infeed duration in picoseconds. - optional uint64 infeed_duration_ps = 3; - // The outfeed duration in picoseconds. - optional uint64 host_outfeed_ps = 8; - // The start time of this step in picoseconds. - optional uint64 begin_ps = 4; - // The waiting time within this step in picoseconds. - optional uint64 wait_duration_ps = 5; - // The unit b outfeed duration in picoseconds. - optional uint64 unit_b_outfeed_ps = 9; - // The time spent on cross-replica-sum in picoseconds. - optional uint64 crs_duration_ps = 6; - // Percentage of unit b time spent on infeed. - optional double unit_b_infeed_percent = 7; -} - -// Result proto for a sequence of steps. -message StepSequenceResult { - // A sequence of StepInfoResults. - repeated StepInfoResult step_sequence = 1; -} - -// Result proto for a StepDatabase. -message StepDatabaseResult { - // A map from core_id to StepSequenceResult. 
- map step_sequence_per_core = 1; -} - -// Result proto for looping-related metrics. -message LoopingResult { - // The total iteration time in nanoseconds. - optional double iteration_time_ns = 1; - // The total number of iterations. - optional int32 num_iterations = 2; - // The total computation time in nanoseconds. - optional double computation_time_ns = 3; - // The total number of computations. - optional int32 num_computations = 4; -} - -// Result proto for HloExtraInfo. -message HloExtraInfoResult { - // Category of the HLO op given by the compiler. - optional string category = 1; - // The long name of the HLO that includes the dimensions. - optional string long_name = 2; - // The per-TPU-core batch size inferred from this HLO. - optional int64 per_core_batch_size = 3; -} - -// Result proto for HloExtraInfoMap. -message HloExtraInfoMapResult { - // A map from HLO name to HloExtraInfo. - map hlo_extrainfo_map = 1; -} - -// Result proto for host-independent job information. -message HostIndependentJobInfoResult { - // The change-list number of this build. - optional int64 change_list = 1; - // The time of this build. - optional int64 build_time = 2; - // The target of this build. - optional string build_target = 3; -} - -// Result proto for host-dependent job information. -message HostDependentJobInfoResult { - // This ID of the host where the job was run on. - optional string host_id = 1; - // The command line used to run the job. - optional string command_line = 2; - // The start time of the job on this host. - optional int64 start_time = 3; -} - -// Result proto for RunEnvironment (the run environment of a profiling session). -message RunEnvironmentResult { - // Number of hosts used. - optional int32 host_count = 1; - // The type of TPU used. - optional string tpu_type = 2; - // The number of TPU cores used. - optional int32 tpu_core_count = 3; - // The per-TPU-core batch size. - optional int32 per_core_batch_size = 4; - // Host-independent job information. 
- optional HostIndependentJobInfoResult host_independent_job_info = 5; - // Host-dependent job information. - repeated HostDependentJobInfoResult host_dependent_job_info = 6; - // The number of replicas, corresponds to input parallelism. - // If there is no model parallelism, replica_count = tpu_core_count - optional int32 replica_count = 7; - // The number of cores used for a single replica, e.g. model parallelism. - // If there is no model parallelism, then num_cores_per_replica = 1 - optional int32 num_cores_per_replica = 8; -} - -// The types of host operations that are tracked. -enum HostOp { - // Invalid host op. - kINVALIDHostOp = 0; - // Each of host op type has two parts: - // (1) the stage where the op happens and (2) the op name. - // stage = Input Data Producer, op = Get Next Batch. - kInputDataProducerGetNextBatch = 1; - // stage = Input Data Producer, op = Session Run. - kInputDataProducerSessionRun = 2; - // stage = Input Data Producer, op = Forward Batch. - kInputDataProducerForwardBatch = 3; - // stage = Infeed Thread, op = Get Next Batch. - kInfeedThreadGetNextBatch = 4; - // stage = Infeed Thread, op = Session Run. - kInfeedThreadSessionRun = 5; - // stage = Infeed Thread, op = Forward Batch. - kInfeedThreadForwardBatch = 6; - // stage = Outfeed Thread, op = Get Next Batch. - kOutfeedThreadGetNextBatch = 7; - // stage = Outfeed Thread, op = Session Run. - kOutfeedThreadSessionRun = 8; - // stage = Outfeed Thread, op = Forward Batch. - kOutfeedThreadForwardBatch = 9; -} - -// Result proto for the host ops per TPU step. -message HostOpsPerTpuStep { - // Whether the data in this message is valid. - optional bool valid = 1 [default = false]; - // The current TPU step number. - optional uint32 tpu_step_num = 2; - // The beginning time of the current TPU step on the device in picoseconds. - optional uint64 tpu_step_begin_ps = 3; - // The ending time of the current TPU step on the device in picoseconds. 
- optional uint64 tpu_step_end_ps = 4; - // For each possible host operation, maps to the difference between the TPU - // step number that the host op targets and the current TPU step number. - // The key is HostOp, value is the step difference. - map step_diffs = 5; -} - -message HostOpsDetailsPerCore { - // Map from core id to HostOpsPerTpuStep. - map core_map = 1; -} - -message HostOpsDetailsPerHost { - // Map from hostname to a map from core id to HostOpsPerTpuStep. - map host_map = 1; -} - -// Result proto for the host ops for all TPU steps. -message HostOpsResult { - reserved 1; // (was repeated HostOpsPerTpuStep host_op_sequence) - // A sequence of records with one for each TPU step. Each record - // is a map from hostname to a map from core id to HostOpsPerTpuStep. - repeated HostOpsDetailsPerHost hostops_details = 2; -} - -// Result proto for TfStatsHelper. -message TfOpStats { - // The result for the TF-metric database. - optional OpMetricsDbResult tf_metrics_db = 1; - // The result for the HLO-metric database. - optional OpMetricsDbResult hlo_metrics_db = 2; - // The result for the step database. - optional StepDatabaseResult step_db = 3; - // The result for the looping-related metrics. - optional LoopingResult looping = 4; - // The result for the HloExtraInfoMap. - optional HloExtraInfoMapResult hlo_extrainfo_map = 5; - // Overall matrix unit utilization in percentage. - optional double matrix_unit_utilization_percent = 6; - // The run environment of this profiling session. - optional RunEnvironmentResult run_environment = 7; - // The result for the host operations. - optional HostOpsResult host_ops = 8; - // A map from core ID to name. - map core_id_to_name_map = 9; - // The result for hw unit b stats. 
- optional bytes unit_b_stats = 10; -} diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index c2e3be03db0e4cca1a664f9e79aa9107384de312..aae1ab1d37a166303883e3a07a7a01efe2feab51 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -154,6 +154,14 @@ message OptimizationParameters { // updates; not present means no limits are applied. ClippingLimits gradient_clipping_limits = 7; + // Amount of weight decay to apply; see weight_decay_optimizers.py for + // details. Almost all optimizers are supported with this option (MDL Adagrad + // Light does not work, and SGD does not behave as expected if it is enabled). + // Although there is no check, users who want weight decay will probably also + // want to enable gradient accumulation as well so that the decay will happen + // once per minibatch. + float weight_decay_factor = 16; + // Whether to use gradient accumulation (do two passes over the input // gradients: one to accumulate them into a temporary array and another to // apply them using the actual optimization algorithm). 
This feature is diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index c32bd5997c1493594d253650d42ae2215b2862a2..1cf7f9fcf67ec98feb02dd4298a36153e689f2e5 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -164,14 +164,15 @@ class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) + + for l in self._listeners: + l.after_save(session, step) + end_time = time.time() logging.info("Checkpoint actual writing time: (%.3f sec)", end_time - start_time) logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - for l in self._listeners: - l.before_save(session, step) - if not asynchronous: _save_fn() return diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index ce2c322ff49382308f87b94fd1dcad8c24d3f540..08f58a5f5b89f92502893e222cbca3bd07b2432b 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1184,8 +1184,7 @@ class TPUFunction(object): # pipelined loop. return None, None - if (self.model.uses_learning_phase and - not isinstance(K.learning_phase(), int)): + if not isinstance(K.learning_phase(), int): # Remove the learning_phase flag at the end. We currently hard code the # learning_phase in TPUFunction. 
assert isinstance(inputs[-1], int), ( @@ -1651,7 +1650,7 @@ class KerasTPUModel(models.Model): self._make_train_function() sample_weights = sample_weights or [] val_sample_weights = val_sample_weights or [] - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + if not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [1] else: ins = inputs + targets + sample_weights diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..70baea203cc6174bebc7d90646045efae5f2391d --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py @@ -0,0 +1,553 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ======================================================================== +"""A utility to trace tensor values on TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path +import re + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging + +_TRACER_LOG_PREFIX = ' [>>>TT>>>]' +_DEVICE_TYPE_TPU = 'tpu' +_DEVICE_TYPE_CPU = 'cpu' +_GLOBAL_STEP_OP_NAME = 'GLOBAL-STEP' +_TRACE_MODE_NAN_INF = 'nan-inf' +_TRACE_MODE_PART_TENSOR = 'part-tensor' +_TRACE_MODE_PART_TENSOR_SIZE = 3 +_TRACE_MODE_FULL_TENSOR = 'full-tensor' +_RECORD_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' +_RECORD_SHOULD_NOT_TRACE = 'not-traced-should-not-trace' +_RECORD_FILTERED_OUT = 'not-traced-filtered-out' +_RECORD_SCALAR = 'not-traced-scalar' +_RECORD_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' +_RECORD_GET_TRACED = 'get-traced' +_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' +_MARKER_SECTION_END = '!!!!!!! 
section-end:' +_SECTION_NAME_CONFIG = 'configuration' +_SECTION_NAME_REASON = 'reason' +_SECTION_NAME_OP_LIST = 'op-list' +_SECTION_NAME_GRAPH = 'graph' +_FIELD_NAME_VERSION = 'version:' +_FIELD_NAME_DEVICE = 'device:' +_FIELD_NAME_TRACE_MODE = 'trace-mode:' +_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' +_FIELD_NAME_NUM_OPS = 'number-of-ops:' +_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' +_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' +_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") +_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') +_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') +_FLAG_NAME_ENABLE = 'enable' +_FLAG_NAME_TRACE_MODE = 'trace_mode' +_FLAG_NAME_INTERESTING_OPS = 'interesting_ops' +_FLAG_NAME_TRACE_FILE = 'trace_file_path' +_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' +_FLAG_NAME_OP_RANGE = 'op_range' +_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') +_OUTPUT_STREAM_ESCAPE = 'file://' +_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' + + +class TensorTracer(object): + """A software construct for tracing tensor values in a TF graph on TPU. + + This utility is disabled by default. It can be enabled by setting + the TENSOR_TRACER_FLAGS env variable as: + export TENSOR_TRACER_FLAGS="--enable=1" + If it is enabled, it will trace the output tensor values of + selected Ops in the graph. It has two outputs: (1) the traces and (2) + a report. The traces are dumped to a specified local file on the TPU + host. The report is printed to the log.info of the TPU job. + By passing options via the env variable, users can change: + (1) the trace mode (e.g., detecting NaN/Inf, printing partial or + full tensor values) + (2) which Ops to be traced (via op.name or op.type) + (3) output trace file path. 
+ """ + + @staticmethod + def _match_next_flag(flags, pos): + """Returns the match for the next TensorTracer flag.""" + + match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) + if match: + return match + match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) + if match: + return match + match = _FLAG_NO_QUOTE_PAT.match(flags, pos) + return match + + @staticmethod + def print_flag_values(): + """Prints all TensorTracer flags passed via environment variables.""" + + tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR + result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, + tensor_tracer_flags) + result += 'Individual flag value:\n' + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + break + flag_name = match.group(1) + flag_value = match.group(2) + result += ' %s: %s\n'%(flag_name, flag_value) + pos = match.end() + result += '\n' + return result + + @staticmethod + def get_flag_value(wanted_flag_name): + """Returns the value of a TensorTracer flags.""" + + tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) + if not tensor_tracer_flags: + return '' + pos = 0 + while True: + match = TensorTracer._match_next_flag(tensor_tracer_flags, pos) + if not match: + return '' + flag_name = match.group(1) + flag_value = match.group(2) + if flag_name == wanted_flag_name: + return flag_value + pos = match.end() + return '' + + @staticmethod + def is_enabled(): + """Returns True if TensorTracer is enabled.""" + + flag_value = TensorTracer.get_flag_value(_FLAG_NAME_ENABLE) + flag_value = flag_value.lower() + enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] + return enabled + + @staticmethod + def use_test_undeclared_outputs_dir(): + """Decides the output directory of the trace file. + + Args: + None. + + Returns: + True if the output trace file should be written to the + test-undeclared-outputs-directory defined via an + env variable. 
+ """ + + flag_value = TensorTracer.get_flag_value( + _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) + flag_value = flag_value.lower() + enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] + return enabled + + @staticmethod + def check_device_type(device_type): + """Checks if the given device type is valid.""" + + if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: + raise ValueError('Invalid device_type "%s"'%device_type) + + @staticmethod + def check_trace_mode(trace_mode): + """Checks if the given trace mode is valid.""" + + valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, + _TRACE_MODE_FULL_TENSOR] + if trace_mode not in valid_trace_modes: + raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' + 'Valid trace modes are: %s'%(trace_mode, + valid_trace_modes)) + + @staticmethod + def should_trace(device_type, op): + """Returns True if the given Op should be traced.""" + + if device_type != _DEVICE_TYPE_TPU: + raise ValueError('Non TPU device type is not supported') + if control_flow_util.IsInCond(op): + return False + if op.type in ['Reshape', 'ArgMin', 'ArgMax']: + return False + # pylint: disable=protected-access + return tpu._TPU_REPLICATE_ATTR in op.node_def.attr + # pylint: enable=protected-access + + @staticmethod + def reason(op_idx, details): + """Returns why the Op at op_idx is traced or not.""" + return '%d %s'%(op_idx, details) + + @staticmethod + def topological_sort(g): + """Performs topological sort on the given graph. + + Args: + g: the graph. + + Returns: + A pair where the first element indicates if the topological + sort succeeded (True if there is no cycle found; False if a + cycle is found) and the second element is either the sorted + list of nodes or the cycle of nodes found. + """ + + def visit(op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops): + """Recursively visits all Ops in a graph. + + Args: + op: the current Op being visited. + cycle: a cycle of Ops found. 
+ permanently_marked_ops: the set of Ops that were already visited. + temporarily_marked_ops: the set of Ops that we have visited during + the current descent. + sorted_ops: the list of Ops sorted in topological order. + """ + + if cycle: + return + if op in permanently_marked_ops: + return + if op in temporarily_marked_ops: + cycle = temporarily_marked_ops + return + temporarily_marked_ops.add(op) + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + for consumer_op in out_tensor.consumers(): + visit(consumer_op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + # pylint: disable=protected-access + for ctrl_output_op in op._control_outputs: + # pylint: enable=protected-access + visit(ctrl_output_op, cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + temporarily_marked_ops.remove(op) + permanently_marked_ops.add(op) + sorted_ops.insert(0, op) + + graph_cycle = set([]) + sorted_ops = [] + permanently_marked_ops = set([]) + temporarily_marked_ops = set([]) + unsorted_ops = g.get_operations() + for op in unsorted_ops: + visit(op, graph_cycle, permanently_marked_ops, + temporarily_marked_ops, sorted_ops) + if graph_cycle: + return (False, graph_cycle) + else: + assert len(unsorted_ops) == len(sorted_ops) + return (True, sorted_ops) + + def __init__(self): + """Initializes a TensorTracer. + + Sets the various member fields from the flags (if given) or the defaults. 
+ """ + self._version = 'use-outside-compilation' + self._device_type = None + self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) + if not self._trace_mode: + self._trace_mode = _TRACE_MODE_NAN_INF + TensorTracer.check_trace_mode(self._trace_mode) + self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE + self._instrument_records = {} + interesting_ops = TensorTracer.get_flag_value(_FLAG_NAME_INTERESTING_OPS) + self._selected_ops = interesting_ops.split() + self._set_trace_file_path() + self._set_op_range() + self._num_replicas = None + self._replica_id = None + + def _add_replica_id_to_graph(self, num_replicas, result_tensor): + """Adds nodes for computing the replica ID to the graph.""" + + if not num_replicas: + self._replica_id = 'unknown' + return result_tensor + + self._num_replicas = num_replicas + + with ops.control_dependencies(None): + # Uses None as dependency to run outside of TPU graph rewrites. + self._replica_id = tpu_ops.tpu_replicated_input( + list(range(self._num_replicas)), + name='tt_replica_id') + use_replica_id = array_ops.identity(self._replica_id).op + with ops.control_dependencies([use_replica_id]): + # Adds a control dependency from the result_tensor to + # the replica_id to ensure that replica_id will be added to the graph. 
+ return array_ops.identity(result_tensor) + + def _set_trace_file_path(self): + """Sets the path of the output trace file.""" + + self._trace_file_path = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_FILE) + if not self._trace_file_path: + raise ValueError('--%s is not set in the environment variable %s' + %(_FLAG_NAME_TRACE_FILE, _FLAGS_ENV_VAR)) + elif TensorTracer.use_test_undeclared_outputs_dir(): + if os.path.isabs(self._trace_file_path): + raise ValueError('If use_test_undeclared_outputs_dir is set,' + 'trace_file_path cannot be an absolute path (%s)' + %self._trace_file_path) + outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) + self._trace_file_path = os.path.join(outputs_dir, + self._trace_file_path) + + def _set_op_range(self): + """Sets the index range of the Ops that we will consider tracing.""" + + op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) + if not op_range: + self._op_range = (-1, -1) # this means including all ops. + return + match = _OP_RANGE_PAT.match(op_range) + if not match: + self._op_range = (-1, -1) # this means including all ops. 
+ return + self._op_range = (int(match.group(1)), int(match.group(2))) + + def _inside_op_range(self, idx): + """Return True if the given index is inside the selected range.""" + + if idx < self._op_range[0]: + return False + return self._op_range[1] < 0 or idx <= self._op_range[1] + + def _write_report(self, content): + """Writes the given content to the report.""" + + logging.info('%s %s'%(_TRACER_LOG_PREFIX, content)) + + def _is_selected_op(self, op_name): + """Returns True if the Op with op_name is selected to be traced.""" + + if not self._selected_ops: + return True + if op_name in self._selected_ops: + return True + return False + + def _write_config_section(self): + """Writes the config section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) + self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) + self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) + self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) + self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) + + def _write_reason_section(self): + """Writes the reason section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) + for key in sorted(self._instrument_records): + self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) + + def _write_op_list_section(self, op_list): + """Writes the Op-list section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) + self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) + for i in range(0, len(op_list)): + self._write_report('%d "%s" %s\n'%(i, op_list[i].name, op_list[i].type)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) + + def 
_write_graph_section(self, succeed, sorted_or_cycle): + """Writes the graph section of the report.""" + + self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) + self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, + succeed)) + l = list(sorted_or_cycle) + for i in range(0, len(l)): + self._write_report('%d "%s"\n'%(i, l[i].name)) + self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) + + def _make_tensor_trace_fun(self, op_name, output_idx): + """Makes the tensor tracing function called by outside compilation. + + Args: + op_name: the name of the Op that outputs the tensor to be traced. + output_idx: which output of the Op it is (0 means the first output). + + Returns: + A function to be passed as the first argument to outside compilation. + + Raises: + RuntimeError: If the trace mode is invalid. + """ + + def _print_tensor(op_name, output_idx, num_elements, tensor, output_tensor): + """Prints a tensor value to a file. + + Args: + op_name: the name of the Op that outputs the tensor to be printed. + output_idx: which output of the Op it is (0 means the first output). + num_elements: number of elements to print. + tensor: the tensor needs to be returned. + output_tensor: the tensor needs to be printed. + + Returns: + The same tensor passed via the "tensor" argument. + """ + msg = '"%s:%d" '%(op_name, output_idx) + output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path + print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor), + ' @', self._replica_id, + '\n', output_tensor, + summarize=num_elements, + output_stream=output_stream) + with ops.control_dependencies([print_op]): + return array_ops.identity(tensor).op + + def _detect_nan_inf(tensor): + """Trace function for detecting any NaN/Inf in the tensor.""" + + if tensor.dtype.is_floating: + # Since host can't handle bf16, always convert tensor to f32. 
+ tensor = math_ops.cast(tensor, dtypes.float32) + output_tensor = math_ops.reduce_any( + gen_math_ops.logical_or(gen_math_ops.is_nan(tensor), + gen_math_ops.is_inf(tensor))) + else: + output_tensor = constant_op.constant(0) + return _print_tensor(op_name, output_idx, 1, tensor, output_tensor) + + def _show_global_step(tensor): + """Trace function for printing the global step count.""" + + return _print_tensor(op_name, output_idx, 1, tensor, tensor) + + def _show_part_tensor(tensor): + """Trace function for printing part of the tensor.""" + + return _print_tensor(op_name, output_idx, self._part_tensor_size, + tensor, tensor) + + def _show_full_tensor(tensor): + """Trace function for printing the entire tensor.""" + + return _print_tensor(op_name, output_idx, -1, tensor, tensor) + + if op_name == _GLOBAL_STEP_OP_NAME: + return _show_global_step + if self._trace_mode == _TRACE_MODE_NAN_INF: + return _detect_nan_inf + if self._trace_mode == _TRACE_MODE_PART_TENSOR: + return _show_part_tensor + if self._trace_mode == _TRACE_MODE_FULL_TENSOR: + return _show_full_tensor + + raise RuntimeError('Tensor trace fun for %s is not yet implemented' + %self._trace_mode) + + def trace_tpu(self, graph, result_tensor, num_replicas=None): + """Traces the tensors generated by TPU Ops in a TF graph. + + Args: + graph: the graph of Ops. + result_tensor: a result tensor of evaluating the graph. + num_replicas: number of replicas used on the TPU. + + Returns: + A tuple (result_tensor_copy, tracing_ops), where: + result_tensor_copy: an exact copy of result_tensor + tracing_ops: a list of tracing ops. If this list + is non empty, the caller of this function + should pose control dependencies upon these + Ops so that they will be executed when the + graph is evaluated. 
+ """ + + self._device_type = _DEVICE_TYPE_TPU + TensorTracer.check_device_type(self._device_type) + result_tensor_copy = self._add_replica_id_to_graph(num_replicas, + result_tensor) + self._write_config_section() + tracing_ops = [] + operations = graph.get_operations() + self._write_op_list_section(operations) + # Does the topological sort before adding any nodes to the graph. + (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) + for op_id, op in enumerate(operations): + if not self._inside_op_range(op_id): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_OUTSIDE_OP_RANGE) + continue + if not TensorTracer.should_trace(self._device_type, op): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_SHOULD_NOT_TRACE) + continue + if not self._is_selected_op(op.name): + self._instrument_records[op.name] = TensorTracer.reason( + op_id, _RECORD_FILTERED_OUT) + continue + for i in range(len(op.outputs)): + out_tensor = op.outputs[i] + if not out_tensor.get_shape().is_fully_defined(): + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_DYNAMIC_SHAPE) + continue # cannot trace tensors with dynamic shape. + rank = len(out_tensor.shape) + if rank < 1: + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_SCALAR) + continue # cannot trace scalar. + self._instrument_records[out_tensor.name] = TensorTracer.reason( + op_id, _RECORD_GET_TRACED) + consumers = out_tensor.consumers() + trace_op = tpu.outside_compilation( + self._make_tensor_trace_fun(op.name, i), out_tensor) + if consumers: + for consumer_op in consumers: + # pylint: disable=protected-access + consumer_op._add_control_input(trace_op) + # pylint: enable=protected-access + else: + # if there is no consumer, we will add the control dependence later + # when we add the control dependency to the output operations. 
+ tracing_ops.append(trace_op) + + self._write_reason_section() + self._write_graph_section(succeed, sorted_or_cycle) + + return (result_tensor_copy, tracing_ops) diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index b6bb5c6e56c74003ed8ceafe9246fb6a05d928dd..6ae718cc2c9716587849aeee8abcd0a1de82a9ae 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -189,12 +189,13 @@ class Topology(object): def cpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the CPU device attached to a logical core.""" return _tpu_host_device_name( - job, self._topology_tasks[device_coordinates]) + job, self._topology_tasks[tuple(device_coordinates)]) def tpu_device_name_at_coordinates(self, device_coordinates, job=None): """Returns the name of the TPU device assigned to a logical core.""" - return _tpu_device_name(job, self._topology_tasks[device_coordinates], - self._topology_devices[device_coordinates]) + return _tpu_device_name(job, + self._topology_tasks[tuple(device_coordinates)], + self._topology_devices[tuple(device_coordinates)]) @property def num_tasks(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 555ad0f1fdbe36f078c7d2fdcc67571f28c8b723..7cb8c4aa7f14636a9597ec45974ec013ef367414 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -31,6 +31,7 @@ import six from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.tpu.python.tpu import tensor_tracer from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support @@ -108,6 +109,15 @@ ops.register_proto_function( 
from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access +def _is_iterable(obj): + """A Python 2 and 3 compatible util to check whether `obj` is iterable.""" + try: + iter(obj) + return True + except TypeError: + return False + + def _create_global_step(graph): graph = graph or ops.get_default_graph() if training.get_global_step(graph) is not None: @@ -1317,9 +1327,15 @@ class _ModelFnWrapper(object): captured_training_hooks.capture(estimator_spec.training_hooks) + tracing_ops = [] + if tensor_tracer.TensorTracer.is_enabled(): + tt = tensor_tracer.TensorTracer() + loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss, + self._ctx.num_replicas) + # We must run train_op to update the variables prior to running the # outfeed. - with ops.control_dependencies([train_op]): + with ops.control_dependencies([train_op]+tracing_ops): host_call_outfeed_ops = [] if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access and estimator_spec.host_call is not None): @@ -2250,8 +2266,7 @@ class TPUEstimator(estimator_lib.Estimator): # Only fetching `tpu_tensors_on_cpu` does not trigger # TPU computation and blocks, so we add the control dependency here. 
control_inputs = ( - tpu_tensors_on_cpu if isinstance(tpu_tensors_on_cpu, - (list, tuple)) else + tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else (tpu_tensors_on_cpu,)) with ops.control_dependencies(control_inputs): new_tensors.append(array_ops.identity(t)) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index afe4c46c8efc59da3da07777ee1fd38be015753d..a701b38d4b3e736a72f20084dbaa6489f1232fb0 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -383,6 +383,7 @@ cc_library( ":lib_platform", ":platform_base", "//tensorflow/core/platform/default/build_config:port", + "@com_google_absl//absl/base", "@snappy", ], ) @@ -1057,6 +1058,7 @@ tf_gen_op_libs( "logging_ops", "manip_ops", "math_ops", + "mkl_nn_ops", "nccl_ops", "nn_ops", "no_op", @@ -1229,7 +1231,7 @@ cc_library( ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", - ] + tf_additional_cloud_op_deps(), + ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(), alwayslink = 1, ) @@ -1285,7 +1287,9 @@ cc_library( ":framework", ":lib", ":nn_ops_op_lib", - ], + ] + if_mkl([ + ":mkl_nn_ops_op_lib", + ]), alwayslink = 1, ) @@ -1668,6 +1672,7 @@ cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", ], ) @@ -2168,6 +2173,7 @@ cc_library( "lib/**/*.cc", "platform/*.cc", "platform/profile_utils/**/*.cc", + ] + [ "framework/resource_handle.cc", "util/env_var.cc", ], @@ -2811,7 +2817,6 @@ tf_cuda_library( ":functional_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:required", - ":core_cpu_impl", ]), alwayslink = 1, ) @@ -2997,6 +3002,16 @@ cc_library( deps = [":lib_internal"], ) +tf_cuda_library( + name = "metrics", + srcs = ["common_runtime/metrics.cc"], + hdrs = ["common_runtime/metrics.h"], + deps = [ + ":lib", + "@com_google_absl//absl/time", + ], +) + tf_cuda_library( name = "direct_session_internal", srcs = 
["common_runtime/direct_session.cc"], @@ -3013,10 +3028,12 @@ tf_cuda_library( ":graph", ":lib", ":lib_internal", + ":metrics", ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", "//tensorflow/core/kernels:function_ops", + "@com_google_absl//absl/time", ], alwayslink = 1, ) @@ -3048,7 +3065,9 @@ tf_cuda_library( ], copts = tf_copts(), cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()), - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow:internal", + ], deps = [ ":core_cpu_internal", ":lib", diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt index cdaeb5091c7b407addec2811bbf0cb79e61db2d2..bfaf3d2ea5912bf5fde34a91ec51ad42f66b6adb 100644 --- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt @@ -4,7 +4,7 @@ op { in_arg { name: "float_values" description: <